## File Description:
# Purpose: Main CER Electricity Trial Analysis
# Author: Brian Prest
# This file does the following:
# 1) Read in and clean data
# 2) Hourly treatment effects
# 3) Balance tests
# 4) Apply Causal Tree Algorithm
#   To install causalTree, see https://github.com/susanathey/causalTree
#   Note: This was run in R version 3.4.1. 
#   causalTree may not be compatible with more recent versions of R
# 5) Compare to Theory-Driven Approach to Heterogeneity
# 6) Evaluate Predictability of Awareness (Theory-Driven and with LASSO)
library(data.table)
library(zoo)
library(lubridate)
library(lfe)
library(dummies)
library(causalTree)
library(xlsx)
library(stargazer)
library(xtable)
library(tree)
library(randomForest)

# String concatenation operator: `a %&% b` is shorthand for paste0(a, b)
`%&%` <- function(x, y) {
  paste0(x, y)
}

# Set directories
# working_dir already ends in '/', so the derived paths contain a doubled
# separator; the path literals used further down are built on these values.
working_dir <- 'your/working/directory/Replication Code/'
data_dir    <- working_dir %&% '/Data/'
output_dir  <- working_dir %&% '/Output/'
graph_dir   <- output_dir %&% '/Graphs/'

options(digits = 7)
# Toggle that is checked before figures are written with pdf()
# write.out <- FALSE
write.out <- TRUE

##### Some useful functions

getsize <- function(object, scaling = 2^20, envir = parent.frame()) {
  # Returns a named vector of object sizes, sorted ascending, plus their total.
  #
  # Args:
  #   object:  character vector of object names (e.g. the result of ls()).
  #   scaling: divisor applied to the byte counts; the default 2^20 gives MB.
  #   envir:   environment in which the names are looked up. Defaults to the
  #            caller's frame so that getsize(ls()) measures the same objects
  #            that ls() listed.
  #
  # Returns:
  #   Numeric vector rounded to one decimal, named by object, with a final
  #   element `total`. An empty `object` yields c(total = 0).
  #
  # Looking names up in `envir` (rather than from inside this function) avoids
  # picking up this function's own locals when a measured object happens to be
  # called `object`, `scaling`, etc.
  force(envir)
  sizes <- vapply(
    object,
    function(nm) as.numeric(object.size(get(nm, envir = envir))),
    numeric(1)
  )
  sizes <- sort(sizes) / scaling
  round(c(sizes, total = sum(sizes)), 1)
}
getsize(ls())

clear.fe <- function(fit) {
  # Strips the memory-heavy components from a fixed-effects fit and returns
  # the slimmed-down object. Components that are not present are ignored.
  bulky.parts <- c('response', 'fitted.values', 'residuals', 'r.residuals',
                   'cfactor', 'cX', 'cY', 'X', 'clustervar')
  for (part in bulky.parts) {
    fit[[part]] <- NULL
  }
  fit
}

#### 1) Read in and Clean CER Data -----
## Read in data
import <- function(...) {
  # Reads the half-hourly consumption data and returns it as one data.table.
  #
  # Primary source: every file matching "^File.*txt$" under
  # data_dir/CER Electricity Data/ConsumpData/, read space-separated, stacked,
  # renamed to (id, date_cer, kw) and keyed on (id, date_cer).
  # Fallback (only when data_dir does not exist): a compressed file whose path
  # is held in a variable `extdata`, streamed through zcat. In that branch the
  # columns are returned as read, without renaming or keying.
  #
  # Args:
  #   ...: unused.
  #
  # Raises:
  #   An error if no matching files are found, or if neither source exists.
  message("importing consumption data...")
  if (dir.exists(data_dir)) {
    files <- list.files(data_dir %&% '/CER Electricity Data/ConsumpData/',
                        pattern = "^File.*txt$", full.names = TRUE)
    # Fail here with a clear message instead of later inside setnames()
    if (length(files) == 0) {
      stop('No consumption files matching "^File.*txt$" found under ', data_dir)
    }
    dts <- lapply(files, fread, sep = " ") # loop through each path and run 'fread' (data.tables import)
    DT <- rbindlist(dts) # stack data
    rm('dts')
    setnames(DT, names(DT), c('id', 'date_cer', 'kw')) # rename data
    setkey(DT, id, date_cer) # set key to (id, date_cer)
  } else if (exists('extdata') && file.exists(extdata)) {
    # `extdata` is not defined in this part of the file, so check that it
    # exists before using it rather than failing with "object not found".
    cmd <- paste('zcat <', extdata)
    DT <- fread(input = cmd)
  } else {
    stop('No CER residential consumption data source')
  }
  DT
}

# Load the raw consumption panel; `dt` is modified step by step below
dt <- import()

find_seq_zeros_nas <- function(DT_KWH, tol = 10, drop = TRUE) {
  # Finds runs of consecutive readings that are zero or NA and, optionally,
  # drops every household that has a run longer than `tol`.
  #
  # Args:
  #   DT_KWH: data.table with columns id, date_cer and kwh. date_cer is decoded
  #           below as day * 100 + within-day slot.
  #   tol:    longest run (number of readings) that is tolerated.
  #   drop:   if TRUE, remove ids with any run longer than tol; if FALSE, only
  #           report the runs and return the data unchanged.
  #
  # Returns:
  #   list(DT1, sq). DT1 is the (possibly filtered) data. sq has one row per
  #   id and run index s, with the run length N; when drop = TRUE only runs of
  #   length <= tol are kept in sq.
  message('searching for long strings of zeros or nas...')
  indx <- DT_KWH[is.na(kwh) | kwh == 0]
  if (nrow(indx) == 0) {
    # Nothing to flag: return the input untouched with an empty run table
    message('no zero or na readings found.')
    return(list(DT_KWH, data.table(id = DT_KWH$id[0], s = numeric(0), N = integer(0))))
  }
  indx[, hour_cer := date_cer %% 100]
  indx[, day_cer := (date_cer - hour_cer) / 100]
  # Gap to the previous flagged reading within each id. The 11111 sentinel
  # makes the first flagged reading of every id start a new run.
  sq <- indx[, .(ddate = c(11111, diff(date_cer)),
                 dday = c(11111, diff(day_cer))), by = 'id']
  sq[, s := 0 + (ddate > 1)]
  # A step of 53 in date_cer together with a one-day step (slot 48 to slot 1
  # of the next day) still counts as consecutive.
  sq[ddate == 53 & dday == 1, s := 0] # find sequences through diff days
  sq[, s := cumsum(s)]
  sq <- sq[, .N, by = c('id', 's')]
  if (drop) {
    drop_ids <- unique(sq[N > tol]$id)
    DT1 <- DT_KWH[!(DT_KWH$id %in% drop_ids)]
    message('dropped ', length(drop_ids), ' ids.')
    sq <- sq[N <= tol]
    n <- uniqueN(DT1$id)
    message(n, ' ids remain with sequential zeros or nas less than ', tol, ' in length.')
  } else {
    # Nothing is removed. DT1 used to be left undefined on this path, which
    # made the return statement fail.
    DT1 <- DT_KWH
    n <- uniqueN(sq$id)
    message(n, ' ids found with sequential zeros.')
  }
  list(DT1, sq)
}

## Import Treatment/Control Assignments ----
# Only the first four columns of the assignment file are read.
dt_assign <- fread(data_dir %&% 'CER Electricity Data/assignments.csv',
                   sep = ',', select = 1:4, na.strings = c("NA", "", "."))
# Keyed lookup on tariff to normalise the lowercase "b" entries
setkey(dt_assign, tariff)
dt_assign["b", tariff := "B"]
setkey(dt_assign, id)
# Residential households only (code 1); the code column is then redundant
dt_assign <- dt_assign[code == 1]
dt_assign[, tar_stim := paste0(tariff, stimulus)]
dt_assign[, code := NULL]
# Exclude the weekend tariff group
dt_assign <- dt_assign[stimulus != 'W']

## Merge DT with treatment group
# Inner merge: meters without a retained assignment drop out here.
dt <- merge(dt, dt_assign, by = 'id')
# any.trt is 0 for the control cell 'EE' and 1 for every other cell
dt[, any.trt := 1]
dt[tar_stim == 'EE', any.trt := 0]
getsize(ls())

# Convert kw to kwh. Each reading covers half an hour, so energy per reading is
# kw * 0.5 hours (a constant 1 kw over a day is 48 readings * 0.5 = 24 kwh).
# Averages are unaffected by the scaling, but sums across half hours are:
# 48 readings of 1 kw would sum to 48 where 24 kwh is wanted.
dt[, kwh := 0.5 * kw]
dt[, kw := NULL] # keep only one of the two to avoid confusion

# Drop households with long runs of zero/NA readings; up to one day (48
# half-hours) is tolerated. Only the filtered panel (first list element) is kept.
dt <- find_seq_zeros_nas(dt, tol = 48)[[1]]
gc()

# Convert to meaningful dates
# date_cer is day * 100 + half-hour slot; day 1 corresponds to 2009-01-01.
dt[, day_cer := floor(date_cer / 100)]
dates <- dt[, unique(day_cer)]
dt.dates <- data.table(day_cer = dates,
                       date = as.Date('2009-01-01') + days(dates - 1))
# Treatment period starts on 2010-01-01
dt.dates[, trt.period := as.numeric(date >= as.Date('2010-01-01'))]
dt.dates[, mon := as.factor(cut.Date(date, breaks = 'months'))]
# wday() codes 1 and 7 are the weekend days
dt.dates[, weekday := as.numeric(!(wday(date) %in% c(1, 7)))]
dt <- merge(dt, dt.dates, by = 'day_cer')
rm(dates, dt.dates)
gc()

# Create half-hourly variable
dt[, half.hr := date_cer %% 100]

# Drop redundant vars
dt[, c('date_cer', 'day_cer') := NULL]
gc()

# Relevel factor variables so control is the base level
dt[, tar_stim := relevel(as.factor(tar_stim), ref = 'EE')]

# Bank holidays: New Years, St. Patricks, Easter, May, June, August, Halloween, Xmas, St. Stephens
# http://www.officeholidays.com/countries/ireland/2010.php
bank.holidays <- as.Date(c(
  '2009-' %&% c('01-01', '03-17', '04-13', '05-04', '06-01', '08-03', '10-26', '12-25', '12-28'),
  '2010-' %&% c('01-01', '03-17', '04-05', '05-03', '06-07', '08-02', '10-25', '12-25', '12-27')
))
# Bank holidays are treated like weekend days
dt[date %in% bank.holidays, weekday := 0]

# Keep only weekdays & non-bank holidays (when TOU was in effect).
# The weekend/holiday rows are set aside in their own table first.
dt.nonweekday <- copy(dt[weekday == 0])
dt <- dt[weekday == 1]
dt[, weekday := NULL]
setkey(dt, 'id')
gc()
getsize(ls())

# Bring in hourly electricity prices
# CJ() returns the tariff x period grid sorted, so within each tariff the rows
# are ordered Day, Night, Peak and the rate vectors below follow that order.
rates <- CJ(tariff = c('A', 'B', 'C', 'D', 'E'), period = c('Day', 'Night', 'Peak'))
rates[tariff == 'A', rate := c(14,   12,   20)]
rates[tariff == 'B', rate := c(13.5, 11,   26)]
rates[tariff == 'C', rate := c(13,   10,   32)]
rates[tariff == 'D', rate := c(12.5, 9,    38)]
rates[tariff == 'E', rate := c(14.1, 14.1, 14.1)]

# Half-hour slots (1-48) belonging to each pricing period
night.periods <- c((23*2 + 1):(24*2), 1:(8*2))      # 23:00 - 8:00 (night)
peak.periods  <- (17*2 + 1):(19*2)                  # 17:00 - 19:00 (peak)
day.periods   <- c((8*2 + 1):(17*2),                # 8:00 - 17:00 (day 1)
                   (19*2 + 1):(23*2))               # 19:00 - 23:00 (day 2)

dt[half.hr %in% night.periods, period := 'Night']
dt[half.hr %in% day.periods,   period := 'Day']
dt[half.hr %in% peak.periods,  period := 'Peak']

dt <- merge(dt, rates, by = c('tariff', 'period'), all.x = TRUE, all.y = FALSE)

# For baseline period, rate was 14.1 for everyone.
# That is deliberately not written into the data: the TOU rates are kept on
# baseline rows so counterfactual TOU costs can be computed for the treatment
# group during the baseline period to find "structural winners".

# Drop tariff and stimulus variables
dt[, c('tariff', 'stimulus') := NULL]
gc()
getsize(ls())

### Read in Survey Data ----
## Pre-trial survey ----
# skip=1 skips the first line of the file; "NA", "" and "." are read as missing
survey.pre = fread(data_dir%&%'CER Electricity Data/cer_survey_pre.csv', skip=1,  na.strings = c("NA", "", "."))
# Find variables that have missing values keyed as something like 999
# (exploratory output only: per-column count of values above 99, then the
# columns that contain any value >= 9)
sapply(survey.pre, function(x) sum(x>99, na.rm=TRUE))
which(sapply(survey.pre, function(x) sum(x>=9, na.rm=TRUE))>0)
survey.pre[n_pct_already_saved>100 ,n_pct_already_saved:=NA] # percents larger than 100, mostly 999s
survey.pre[n_year_built>2010 ,n_year_built:=NA] #  9999s
survey.pre[n_home_size==999999999 ,n_home_size:=NA] # 999999999
# 3.28084 feet per meter -> 3.28084^2 square feet per square meter
# n_home_size_adj starts as the reported size and is rescaled where
# f_home_size_mtrs_or_ft==1 (presumably the "meters" answer -- confirm against the codebook)
survey.pre[, n_home_size_adj := as.numeric(n_home_size)]
survey.pre[f_home_size_mtrs_or_ft==1, n_home_size_adj:=as.numeric(n_home_size)*(3.28084^2)]

# "Refused" -> NA
survey.pre[f_age==7, f_age:=NA]
survey.pre[f_social_class==6, f_social_class:=NA]
survey.pre[f_home_style==6, f_home_style:=NA]
survey.pre[n_bedrooms==6, n_bedrooms:=NA]
survey.pre[f_education==6, f_education:=NA]
survey.pre[f_hhincome1==6, f_hhincome1:=NA]
survey.pre[f_hhincome2==6, f_hhincome2:=NA]

# Income question asked twice: first as a # and then as a category.
# Use the first answer and fall back to the second where the first is missing.
survey.pre[, f_hhincome := f_hhincome1]
survey.pre[is.na(f_hhincome), f_hhincome := f_hhincome2]

# In these questions, 8 maps to "None"
survey.pre[n_adults_at_home_during_day==8, n_adults_at_home_during_day:=0]
survey.pre[n_kids_at_home_during_day==8, n_kids_at_home_during_day:=0]

# f_hhincome_per_what: 1 (weekly), 2 (monthly), 3 (yearly)
survey.pre[, table(f_hhincome, f_hhincome_per_what, exclude=NULL)]
# It's mostly yearly. Some said weekly, but scaling this up by 52 would probably be inappropriate
# Non-yearly answers are blanked instead; rows where f_hhincome_per_what is NA
# are not matched by this filter and keep their income value.
survey.pre[f_hhincome_per_what!=3, f_hhincome:=NA]
survey.pre[, table(f_hhincome, f_hhincome_per_what, exclude=NULL)]

## Read in post-trial survey ----
# Note: I changed the duplicate id 5059 in this file to 5061, which is the next entry in the pre-trial
# survey spreadsheet.
# Same read settings as the pre-trial survey: first line skipped, "NA"/""/"." read as missing
survey.post = fread(data_dir%&%'CER Electricity Data/cer_survey_post.csv', skip=1,  na.strings = c("NA", "", "."))
# Exploratory: which columns contain any value >= 9 (candidates for sentinel codes)
which(sapply(survey.post, function(x) sum(x>=9, na.rm=TRUE))>0)

# "Don't know" peak tariff value -> NA
# (999 is the sentinel in all three reported-tariff questions)
survey.post[n_reported_peak_tariff==999, n_reported_peak_tariff:=NA]
survey.post[n_reported_night_tariff==999, n_reported_night_tariff:=NA]
survey.post[n_reported_shoulder_tariff==999, n_reported_shoulder_tariff:=NA]

# In these questions, 8 maps to "None"
survey.post[n_adults_at_home_post==8, n_adults_at_home_post:=0]
survey.post[n_kids_at_home_post==8, n_kids_at_home_post:=0]

# "Refused" -> NA
survey.post[f_age_post==7, f_age_post:=NA]
survey.post[f_social_class_post==6, f_social_class_post:=NA]

#### Merge surveys together ----
# Inner merges: only households present in the pre-trial survey, the post-trial
# survey and the retained treatment assignments are kept.
survey <- merge(survey.pre, survey.post, by = 'id')
# survey[, id:=as.factor(id)]
survey <- merge(survey, dt_assign, by = 'id', all = FALSE)

dim(survey)
# Completeness checks (exploratory output). Each expression computes, per
# column, the percentage of non-missing values; all four use the same
# parenthesisation, and TRUE is spelled out instead of the reassignable T.
# How many variables are 100% complete?
sum(sapply(survey, function(x) sum(!is.na(x)) * 100) / survey[, .N] == 100)
# Just id
which(sapply(survey, function(x) sum(!is.na(x)) * 100) / survey[, .N] == 100, useNames = TRUE)

# How many variables at >70% complete?
sum(sapply(survey, function(x) sum(!is.na(x)) * 100) / survey[, .N] > 70)
which(sapply(survey, function(x) sum(!is.na(x)) * 100) / survey[, .N] > 70, useNames = TRUE)

## Adjustments for Coding of Survey Data ----
# Household Characteristics
# Lookup vectors from numeric survey codes to text labels. The names are the
# codes '1', '2', ... in order, so code k is the k-th element whether the
# vector is indexed by position or by name.
gender.key = c('1'='Male','2'='Female')
age.key = c('1'='18-25','2'='26-35','3'='36-45','4'='46-55','5'='56-65','6'='65+','7'='Refused')
empl.key = c('1'='Employed','2'='Self-Employed with Employees','3'='Self-Employed',
             '4'='Unemployed (seeking)','5'='Unemployed (not seeking)','6'='Retired','7'='Carer')
class.key = c('1'='AB','2'='C1','3'='C2','4'='DE','5'='Farmer','6'='Refused')
educ.key = c('1'='None','2'='Primary','3'='Secondary to Cert','4'='Secondary without Cert','5'='Third','6'='Refused')
own.key = c('1'='Rent (pvt)','2'='Rent (public)','3'='Own outright','4'='Own with mortgage','5'='Other')
style.key = c('1'='Apartment','2'='Semi-detached','3'='Detached','4'='Terraced','5'='Bungalow','6'='Refused')
home.age.key = c('1'='[0,5) years','2'='[5,10) years','3'='[10,30) years','4'='[30,75) years','5'='75plus years')
cook.key = c('1'='ElectricCook', '2'='GasCook', '3'='OilCook', '4'='SolidCook')
# Create text-valued fs_* versions of the coded variables (missing codes stay missing)
survey[, fs_gender:=gender.key[f_gender]]
survey[, fs_gender_post:=gender.key[d_gender_post]]
survey[, fs_age:=age.key[f_age]]
survey[, fs_employment_status:=empl.key[f_employment_status]]
survey[, fs_employment_status_post:=empl.key[f_employment_post]]
survey[, fs_social_class:=class.key[f_social_class]]
survey[, fs_education:=educ.key[f_education]]
survey[, fs_own_or_rent:=own.key[f_own_or_rent]]
survey[, fs_home_style:=style.key[f_home_style]]
survey[, fs_approx_home_age:=home.age.key[f_approx_home_age]]
survey[, fs_cook_type:=cook.key[f_cook_type]]

# Compute approx. home ages based on year built
# Only fills fs_approx_home_age where the categorical answer is missing;
# age is measured relative to 2009 and binned into the home.age.key brackets.
survey[, est_home_age:=2009-n_year_built]
survey[is.na(fs_approx_home_age) & between(est_home_age, 0, 4.9999), fs_approx_home_age:=home.age.key[1]]
survey[is.na(fs_approx_home_age) & between(est_home_age, 5, 9.9999), fs_approx_home_age:=home.age.key[2]]
survey[is.na(fs_approx_home_age) & between(est_home_age, 10, 29.9999), fs_approx_home_age:=home.age.key[3]]
survey[is.na(fs_approx_home_age) & between(est_home_age, 30, 74.9999), fs_approx_home_age:=home.age.key[4]]
survey[is.na(fs_approx_home_age) & between(est_home_age, 75, 9999999), fs_approx_home_age:=home.age.key[5]]
survey[, est_home_age:=NULL]

# Delete the original variable to avoid collinearity
# (the coded f_* columns are superseded by the fs_* label columns created above)
del.vars = 'f_'%&%c('gender','age', 'employment_status','employment_post','social_class','education',
                    'own_or_rent','home_style','approx_home_age','cook_type')
survey[, (del.vars):=NULL]
survey[, d_gender_post:=NULL] # duplicate variable

# Remove post-trial survey answers that would not be available when designing
# the policy, or that are endogenous to treatment.
del.vars <- c(
  'n_save_energy_bill_post', 'n_save_energy_environ_post',
  'n_can_save_energy_post', 'n_already_done_alot_post',
  'n_already_made_changes_post', 'n_want_to_do_more_post',
  'n_know_how_to_save_post', 'n_cannot_control_elec_post',
  'n_reducing_inconvenient_post', 'n_do_not_know_appliances_post',
  'n_cannot_control_others_post', 'n_not_enough_time_post',
  'n_dont_want_to_be_told_post', 'n_reducing_is_futile_post',
  'f_change_in_awareness_post', 'f_heat_type_post',
  'd_like_cold_post', 'd_cant_afford_heat_post',
  'd_not_well_insulated_post', 'd_why_not_warm_other_post',
  'd_forgo_heat_lack_money_post', 'd_forgo_heat_cold_day_post',
  'd_forgo_heat_bedtime_post', 'd_forgo_heat_delay_post',
  'n_difficulty_of_reducing_post', 'n_i_need_to_reduce_post'
)
survey[, (del.vars) := NULL]
# Variables indicating peoples stated reaction
# del.vars = grep('affected', names(survey), v=T)
dim(survey)

# Convert these to dummy variables (rpart has trouble on CV with factor variables apparently)
# add.dummy(x, var): replaces column `var` of data.table `x` with one indicator
# column per level and returns the widened table.
#   x:   data.table holding the column to expand
#   var: name of that column, as a character string
add.dummy = function(x, var) {
  # dummy() is presumably dummies::dummy (the package is loaded at the top of the file)
  x.dums = dummy(x[[var]])
  # Text of the expression the caller passed as `x` (e.g. "survey")
  x.name = deparse(substitute(x))
  # NOTE(review): this assumes the column names produced by dummy() start with
  # that text, so replacing it yields names of the form "<var>.<level>" --
  # confirm against the dummies package. gsub() treats x.name as a regular
  # expression and replaces every occurrence, not only a leading one.
  colnames(x.dums) = gsub(x.name, var%&%'.', colnames(x.dums))
  # Drop the original column and append the indicator columns
  x = cbind(x[, -var, with=FALSE], x.dums)  
  return(x)
}
# Label-valued columns to expand, one call per column
fs.vars = c('fs_gender','fs_gender_post','fs_age','fs_employment_status','fs_employment_status_post',
            'fs_social_class','fs_education','fs_own_or_rent','fs_home_style','fs_approx_home_age','fs_cook_type')
for (v in fs.vars) survey = add.dummy(survey, v)

# Fix dummies to be 0/1 instead of 1/2
survey[d_aware_of_tariff_change==2, d_aware_of_tariff_change:=0] # 1=Yes, 2=No
# set.seed(1)
# survey[tar_stim=='EE', rbinom(.N, 1, prob=0.5)]
survey[f_internet==2, f_internet:=0] # 1=Yes, 2=No
survey[f_use_internet==2, f_use_internet:=0] # 1=Yes, 2=No
survey[f_others_internet==2, f_others_internet:=0] # 1=Yes, 2=No
# d_has_kids is a new 0/1 column derived from the three-way f_has_kids answer
survey[f_has_kids==1, d_has_kids:=0] # 1=Live alone, 2=Multiple adults>=15, 3=Has kids<15
survey[f_has_kids==2, d_has_kids:=0] 
survey[f_has_kids==3, d_has_kids:=1] 
survey[d_heater_timer==2, d_heater_timer:=0] # 1=Yes, 2=No

# f_when_attic_insulated asks "when was your attic insulated?" 1/2 are <5y/>5y. 3 is No. 4 is Don't Know
# d_attic_insulated is a new column; "Don't Know" (4) is left as NA
survey[f_when_attic_insulated %in% c(1,2), d_attic_insulated:=1]
survey[f_when_attic_insulated==3, d_attic_insulated:=0]
# d_walls_insulated 1=Yes, 2=No, 3=DK
survey[d_walls_insulated==3, d_walls_insulated:=NA] # 3 is Don't Know
survey[d_walls_insulated==2, d_walls_insulated:=0] # 1=Yes, 2=No
survey[d_hot_water_lagging_jacket==2, d_hot_water_lagging_jacket:=0] # 1=Yes, 2=No

# Fill in household-size counts implied by the f_has_kids answer, then drop it
survey[d_has_kids==0, n_kids:=0] # This is "I live alone" or "All adults"
survey[f_has_kids==1, n_adults:=1] # This category is "I live alone". They didn't ask the n_adults question if this was reported
survey[, f_has_kids:=NULL]
survey[, n_residents := n_adults+n_kids]

# All the appliances coded 1 as "None", 2 as 1 machine, 3 as 2+ machines.
# Subtracting 1 turns the codes into 0 / 1 / 2, where 2 stands for "2 or more".
survey[, `:=`(n_washing_machine=n_washing_machine-1, n_tumble_dryer=n_tumble_dryer-1, 
              n_dishwasher=n_dishwasher-1, n_immersion=n_immersion-1,
              n_elec_shower_instant=n_elec_shower_instant-1, n_elec_shower_pumped=n_elec_shower_pumped-1,
              n_elec_cook=n_elec_cook-1, n_elec_heater_appliance_plugin=n_elec_heater_appliance_plugin-1,
              n_standalone_freezer=n_standalone_freezer-1, n_water_pump=n_water_pump-1)]

# Aggregate the two kinds of electric showers.
survey[, n_elec_shower := n_elec_shower_instant + n_elec_shower_pumped]   

# Sum total number of appliances.
# (any NA component makes the total NA)
survey[, n_appliances := n_washing_machine+n_tumble_dryer+n_dishwasher+n_immersion+
               n_elec_shower_instant+n_elec_shower_pumped+n_elec_cook+n_elec_heater_appliance_plugin+
               n_standalone_freezer+n_water_pump]

# All the electronics coded 1 as "None", 2 as 1 machine, 3 as 2+ machines.
# (n_laptop_compter is the column's actual spelling in the data)
survey[, `:=`(n_tv_small=n_tv_small-1, n_tv_large=n_tv_large-1, n_desktop_computer=n_desktop_computer-1,
              n_laptop_compter=n_laptop_compter-1, n_game_console=n_game_console-1)]
# Sum total number of electronics
survey[, n_electronics := n_tv_small + n_tv_large + n_desktop_computer + n_laptop_compter + n_game_console]

## Check for balance ---
# Survey covariates compared between the control cell ('EE') and all others.
bal.vars <- c('n_residents', 'n_adults', 'n_kids', 'n_bedrooms',
              'd_heat_elec', 'd_elec_cook', 'n_appliances', 'n_electronics')

which(names(survey) %in% bal.vars)
survey[, lapply(.SD, mean, na.rm = TRUE), by = tar_stim == 'EE',
       .SDcols = which(names(survey) %in% bal.vars)]

# Attach each household's baseline-period mean usage: first over all
# half-hours (kwh), then over peak half-hours only (peak.kwh).
dt.baseline.cons <- dt[trt.period == 0, .(kwh = mean(kwh)), by = id]
survey <- merge(survey, dt.baseline.cons, by = 'id', all.x = TRUE)
dt.baseline.peak.cons <- dt[trt.period == 0 & period == 'Peak',
                            .(peak.kwh = mean(kwh)), by = id]
survey <- merge(survey, dt.baseline.peak.cons, by = 'id', all.x = TRUE)

# Drop the small number of households without energy use data
survey <- survey[!is.na(kwh)]

# Treatment group as a factor with the control cell as reference level
survey[, tar_stim := relevel(as.factor(tar_stim), ref = 'EE')]

# Check Balance on Baseline Consumption (Table 3):
# by individual treatment cell, then control vs. everyone else
summary(lm(peak.kwh ~ tar_stim, data = survey))
summary(lm(kwh ~ tar_stim, data = survey))
summary(lm(peak.kwh ~ tar_stim == 'EE', data = survey))
summary(lm(kwh ~ tar_stim == 'EE', data = survey))

# Cleaned survey data on ~3000 households
dim(survey)
dt

## Output table of treatment/control distribution ----
# Rows are stimulus groups, columns are tariff groups.
trt.ctrl.table <- table(survey[, .(stimulus, tariff)])
# Column labels must come from the column (tariff) names. They were previously
# built from rownames(), i.e. the stimulus codes, which mislabelled the tariffs.
colnames(trt.ctrl.table) <- paste("Tariff", colnames(trt.ctrl.table))
# Tariff E is the control tariff
colnames(trt.ctrl.table)[colnames(trt.ctrl.table) == 'Tariff E'] <- 'Control'
# Descriptive row labels, in the sorted order of the stimulus codes.
# NOTE(review): assumes exactly five stimulus levels with the control level
# sorting last -- confirm against the assignment file.
rownames(trt.ctrl.table) <- c('Bi-monthly bill and energy usage statement', 'Monthly bill and energy usage statement',
                              'Bi-monthly bill, energy usage statement, and in-home display',
                              'Bi-monthly bill, energy usage statement, and Overall Load Reduction incentive',
                              'Control')
# Append column totals as a final row, then row totals as a final column
trt.ctrl.table <- rbind(trt.ctrl.table, colSums(trt.ctrl.table))
trt.ctrl.table <- cbind(trt.ctrl.table, rowSums(trt.ctrl.table))
colnames(trt.ctrl.table)[6] <- 'Total'
rownames(trt.ctrl.table)[6] <- 'Total'

# Treatment/control distribution (Table 1)
xtable(trt.ctrl.table, digits = 0)

## Check difference in group responses by in survey/not in survey ----
# Households without survey data are dropped further down. Before that:
# what is the average treatment effect among ALL households, regardless of
# survey data, and does it differ statistically for survey respondents?
# Mean peak-period kwh per household and period (baseline / treatment)
dt.all.hh <- dt[period == 'Peak', .(kwh = mean(kwh)),
                by = .(id, any.trt, trt.period)]
# Cast to wide so we can compute change in consumption
dt.all.hh <- dcast(dt.all.hh,
                   id + any.trt ~ ifelse(trt.period == 1, 'trt.kwh', 'base.kwh'),
                   value.var = c('kwh'))
# Proportional change from baseline to treatment period
dt.all.hh[, d.kwh := (trt.kwh - base.kwh) / base.kwh]

dt[, uniqueN(id)]                  # number in data
survey[, .N]                       # number who completed surveys
survey[, .N] / dt[, uniqueN(id)]   # share who completed the surveys

# Determine who is in survey data (0/1 flag)
dt.all.hh[, insurvey := as.numeric(id %in% survey[, unique(id)])]

# Treatment effects by in survey/not in survey
summary(lm(d.kwh ~ any.trt, data = dt.all.hh))                 # all households
summary(lm(d.kwh ~ any.trt, data = dt.all.hh[insurvey == 1]))  # in survey
summary(lm(d.kwh ~ any.trt * insurvey, data = dt.all.hh))      # interaction

## Now, only keep dt data for which we have survey data observations ---
# This saves on memory. Both panels are also written to disk so they can be
# reloaded after being removed from the workspace.
dt <- dt[id %in% survey[, unique(id)]]
dt.nonweekday <- dt.nonweekday[id %in% survey[, unique(id)]]
save(dt, file = output_dir %&% '/consumption_panel.Rdata')
save(dt.nonweekday, file = output_dir %&% '/consumption_panel_nonweekday.Rdata')
getsize(ls()); gc()

## Plot of tariffs' rate schedules (Figure 1) ----
# Close a leftover graphics device, if any. A bare dev.off() stops with an
# error when only the null device (device 1) is open.
if (dev.cur() > 1) dev.off()

# One row per tariff x half-hour slot. The pricing period comes from the same
# slot definitions used to label the consumption panel, rather than from one
# particular household-day, so the figure does not depend on that household
# having a complete day of readings.
x <- CJ(tariff = c('A', 'B', 'C', 'D', 'E'), half.hr = 1:48)
x[, hour := half.hr / 2]
x[half.hr %in% night.periods, period := 'Night']
x[half.hr %in% day.periods, period := 'Day']
x[half.hr %in% peak.periods, period := 'Peak']
x <- merge(x, rates, by = c('tariff', 'period'))
x <- x[order(tariff, half.hr)]

xgrid <- seq(0, 24, by = 2)
my.cols <- rainbow(4, start = 0.5, end = 1)

if (write.out == TRUE) pdf(file = graph_dir %&% '/Tariff-Structure-' %&% today() %&% '.pdf', family = 'serif', width = 10, height = 7)
par(mfrow = c(1, 1), family = 'serif', mai = c(0.85, 0.85, 0.4, 0.15)) # for pdf output
# par(mfrow=c(1,1), mai=c(0.5,0.5,0.25,0.15)) # for pdf output
# The y-axis label contained a mis-encoded character; it is written as a
# Unicode escape (euro sign) so it no longer depends on the file's encoding.
# NOTE(review): the original symbol is assumed to be the euro sign; if the pdf
# device cannot render it, pass a suitable `encoding` to pdf().
plot(rate ~ I(hour - 0.5), data = x[tariff == 'E'], type = 'l', col = 'black', lty = 1, lwd = 2, main = '',
     xlab = 'Hour of Day', ylab = 'Price (\u20ac cents per kWh)', ylim = c(0, 40), xaxt = 'n', cex.lab = 1.5)
axis(1, xgrid, cex.axis = 1.5); abline(v = xgrid, col = 'gray70', lty = 3); grid(col = 'gray70', nx = NA, ny = NULL)
# abline(v=c(17, 19), col='red', lty=2, lwd=2); abline(v=c(8, 23), col='darkblue', lty=2, lwd=2)
# Step lines for tariffs D to A, drawn thickest first so all stay visible
lines(rate ~ I(hour - 0.5), type = 's', data = x[tariff == 'D'], col = my.cols[4], lwd = 5)
lines(rate ~ I(hour - 0.5), type = 's', data = x[tariff == 'C'], col = my.cols[3], lwd = 4)
lines(rate ~ I(hour - 0.5), type = 's', data = x[tariff == 'B'], col = my.cols[2], lwd = 3)
lines(rate ~ I(hour - 0.5), type = 's', data = x[tariff == 'A'], col = my.cols[1], lwd = 2)

# Print each tariff's peak rate just above its line in the peak window
for (tar in c('A', 'B', 'C', 'D', 'E')) {
  text(x = 18, y = rates[tariff == tar & period == 'Peak', rate] + 1, label = rates[tariff == tar & period == 'Peak', rate])
}
# Legend entries are listed from tariff D down to control, matching the line widths
legend(x = 6, y = 40, legend = rev(c('Control (No Change)', 'Tariff A', 'Tariff B', 'Tariff C', 'Tariff D')), bty = 'n',
       lty = 1, col = rev(c('Black', my.cols)), lwd = rev(c(2, 2:5)), cex = 1.5)

dev.off()

#### 2) Run all analysis that requires half-hourly resolution ------
## Graph half-hourly average consumption by {trt, ctrl}x{baseline period, trt period}
## (Figure 3)
# Mean kwh and number of distinct households per half-hour slot, group and period
dt.halfhr.avg <- dt[, .(kwh = mean(kwh), uniqueN(id)), by = .(half.hr, any.trt = 1 * (tar_stim != 'EE'), trt.period)]
dt.halfhr.avg <- dt.halfhr.avg[order(half.hr, any.trt, trt.period)]
gc()

xgrid <- seq(0, 24, by = 2)
# Close a leftover graphics device, if any. A bare dev.off() stops with an
# error when only the null device is open.
if (dev.cur() > 1) dev.off()
if (write.out == TRUE) pdf(file = graph_dir %&% '/Avg-HalfHourly-Consumption-by-Trt-Group-and-Period-' %&% today() %&% '.pdf', family = 'serif', width = 10, height = 7)
par(mfrow = c(1, 1), family = 'serif', mai = c(0.85, 0.85, 0.4, 0.15)) # for pdf output
plot(kwh ~ I(half.hr / 2), data = dt.halfhr.avg, type = 'n', ylim = c(0, 0.5),
     ylab = 'Consumption (kWh per 30 minutes)', xlab = 'Hour of Day', xaxt = 'n', cex.lab = 1.5, cex.axis = 1.5)
axis(1, xgrid, cex.axis = 1.5); abline(v = xgrid, col = 'gray70', lty = 3); grid(col = 'gray70', nx = NA, ny = NULL)
# Points are shifted left by 0.25 h so they sit in the middle of each half-hour
# window. Solid = baseline period, dashed = treatment period.
lines(kwh ~ I(half.hr / 2 - 0.25), data = dt.halfhr.avg[any.trt == 0 & trt.period == 0], lwd = 3, lty = 1, col = 'gray40', type = 'l')
lines(kwh ~ I(half.hr / 2 - 0.25), data = dt.halfhr.avg[any.trt == 0 & trt.period == 1], lwd = 3, lty = 2, col = 'gray40', type = 'l')
lines(kwh ~ I(half.hr / 2 - 0.25), data = dt.halfhr.avg[any.trt == 1 & trt.period == 0], lwd = 3, lty = 1, col = 'green', type = 'l')
lines(kwh ~ I(half.hr / 2 - 0.25), data = dt.halfhr.avg[any.trt == 1 & trt.period == 1], lwd = 3, lty = 2, col = 'green', type = 'l')
# Period boundaries: peak window (17-19) and night window (23-8)
abline(v = c(17, 19), col = 'darkred', lty = 2, lwd = 2); abline(v = c(8, 23), col = 'darkblue', lty = 2, lwd = 2)
legend(x = 10, y = 0.51, bty = 'n', legend = c('Control Group', 'Treatment Group'), pch = 22, pt.cex = 2, pt.bg = c('gray40', 'green'), cex = 1.2)
legend(x = 9.5, y = 0.46, bty = 'n', legend = c('Baseline Period', 'Treatment Period'), lty = 1:2, lwd = 2, cex = 1.2)
text.height <- 0.03; text.size <- 0.95
# The rate labels contained a mis-encoded character; it is written as a
# Unicode escape (cent sign) so it no longer depends on the file's encoding.
# NOTE(review): the original symbol is assumed to be the cent sign -- confirm
# against the published figure.
text(x = 5, y = text.height, labels = 'Night Rate \n(9-12\u00a2)', col = 'darkblue', cex = text.size)
text(x = 24.1, y = text.height, labels = 'Night \n Rate \n(9-12\u00a2)', col = 'darkblue', cex = text.size)
text(x = 13, y = text.height, labels = 'Day Rate \n(12.5-14\u00a2)', col = 'darkorange', cex = text.size)
text(x = 21, y = text.height, labels = 'Day Rate \n(12.5-14\u00a2)', col = 'darkorange', cex = text.size)
text(x = 18.05, y = text.height, labels = 'Peak Rate \n(20-38\u00a2)', col = 'red', cex = text.size)

dev.off()

dt.halfhr.avg[any.trt == 1]

getsize(ls())

## Graph time series of peak consumption aggregated to daily level (Figure 4)
# compute.ses(data): standard error of the mean of kwh, from an intercept-only
# felm fit clustered on id (second column of the coefficient table).
compute.ses <- function(data) {
  summary(felm(kwh ~ 1 | 0 | 0 | id, data = data))$coefficients[, 2]
}

# Subset the peak-period rows once instead of twice per date inside the loop
dt.peak <- dt[period == 'Peak']

# Daily mean peak kwh by group, with the naive (unclustered) SE and cell size
avg.kwh <- dt.peak[, .(kwh = mean(kwh), se.naive = sd(kwh) / sqrt(.N), .N), by = .(date, any.trt = (tar_stim != 'EE'))]

unique.dates <- unique(avg.kwh[, date])
SEs.trt <- SEs.ctrl <- numeric(length(unique.dates))
# Get SEs with clustered standard errors, one fit per date and group.
# seq_along() keeps the loop from running when there are no dates.
for (i in seq_along(unique.dates)) {
  SEs.trt[i] <- suppressWarnings(compute.ses(dt.peak[date == unique.dates[i] & tar_stim != 'EE']))
  SEs.ctrl[i] <- suppressWarnings(compute.ses(dt.peak[date == unique.dates[i] & tar_stim == 'EE']))
}
rm(dt.peak)
SEs <- rbind(data.table(any.trt = TRUE, date = unique.dates, se = SEs.trt),
             data.table(any.trt = FALSE, date = unique.dates, se = SEs.ctrl))
avg.kwh <- merge(avg.kwh, SEs, by = c('any.trt', 'date'))
avg.kwh[, tt := 1:.N, by = any.trt]
# 95% band from the clustered SEs
avg.kwh[, kwh.upr := kwh + 1.96 * se]
avg.kwh[, kwh.lwr := kwh - 1.96 * se]

# Check correlations of day-to-day changes between the two groups:
# over the whole sample, then over the baseline period only
cor(diff(avg.kwh[any.trt == TRUE, kwh]), diff(avg.kwh[any.trt == FALSE, kwh]))
cor(diff(avg.kwh[any.trt == TRUE & date <= '2009-12-31', kwh]), diff(avg.kwh[any.trt == FALSE & date <= '2009-12-31', kwh]))

## Graph daily consumption by treatment/control groups
# Only open the PDF device when write.out is set, as the other figures do;
# otherwise the plot goes to the current/default device.
if (write.out == TRUE) pdf(graph_dir %&% 'Daily_Consumption_' %&% today() %&% '.pdf', family = 'serif', width = 10, height = 7)
par(mai = c(1, 1, 0.5, 0.4))
plot(kwh ~ date, data = avg.kwh[any.trt == TRUE], col = 'green', type = 'l', lwd = 2, ylim = c(0, 0.7),
     xlab = 'Date', ylab = 'Peak Period Consumption (kWh per 30 minutes)', xaxt = 'n',
     cex.lab = 1.5, cex.axis = 1.5)
abline(v = seq.Date(as.Date('2009-07-01'), as.Date('2011-01-01'), by = 'months'), lty = 3, col = 'gray70')
axis.Date(1, at = seq.Date(as.Date('2009-07-01'), as.Date('2011-01-01'), by = 'months'), format = '%b-%Y', cex.axis = 1.5)
grid(col = 'gray70', nx = NA, ny = NULL)
# Vertical line marks the start of the treatment period
abline(v = as.Date('2010-01-01'), h = 0)
# Shaded 95% bands: trace the upper bound forward, the lower bound backward
polygon(x = c(avg.kwh[any.trt == TRUE, date], rev(avg.kwh[any.trt == TRUE, date])),
        y = c(avg.kwh[any.trt == TRUE, kwh.upr], rev(avg.kwh[any.trt == TRUE, kwh.lwr])),
        col = rgb(0, 1, 0, 0.25), border = NA)
lines(kwh ~ date, data = avg.kwh[any.trt == FALSE], col = 'gray40', lwd = 2)
polygon(x = c(avg.kwh[any.trt == FALSE, date], rev(avg.kwh[any.trt == FALSE, date])),
        y = c(avg.kwh[any.trt == FALSE, kwh.upr], rev(avg.kwh[any.trt == FALSE, kwh.lwr])),
        col = rgb(0, 0, 0, 0.25), border = NA)
# Legend colours match the plotted lines (the control line is drawn in gray40)
legend(x = as.Date('2010-02-01'), y = 0.7, legend = c('Control Group', 'Treatment Group'), col = c('gray40', 'green'), lty = 1, lwd = 3, cex = 1.5, bty = 'n')
dev.off()

#### Illustrate and estimate half-hourly treatment effects -----
### Half Hourly Treatment Effects ----
# We will do it separately for weekdays (when TOU pricing was in effect)
# and weekend/holidays (when TOU pricing was not in effect)
# Create array to hold results (48 half hours by 2 (weekday/not))
# Storage note: results are written with [[i]][[w]], so element i (half-hour
# i) ends up holding a list of two summaries (1 = weekdays,
# 2 = weekend/holidays). They must be read back with the same [[i]][[w]] form.
fe.fit.hhour.summaries <- array(list(), dim = c(48, 2), dimnames = list(1:48, c('Weekdays', 'Weekend/Holidays')))
start <- Sys.time()
maxiter <- 48 * 2
iter <- 1
gc()

# Year-week identifier used as a fixed effect and as a clustering dimension
dt[, wk := year(date) %&% '-' %&% week(date)]
dt.nonweekday[, wk := year(date) %&% '-' %&% week(date)]
for (w in 1:2) {
  for (i in 1:48) {
    # Difference-in-differences in log consumption for half-hour i, with
    # household and week fixed effects and two-way clustered errors.
    # 0.0001 is added so that zero readings have a defined log.
    if (w == 1) {
      fe.fit <- felm(log(kwh + 0.0001) ~ any.trt:trt.period + trt.period |
                       id + wk | 0 | id + wk, data = dt[half.hr == i])
    } else {
      fe.fit <- felm(log(kwh + 0.0001) ~ any.trt:trt.period + trt.period |
                       id + wk | 0 | id + wk, data = dt.nonweekday[half.hr == i])
    }
    # Keep only the summary and strip its large components to save memory
    fe.fit.hhour.summaries[[i]][[w]] <- summary(fe.fit)
    fe.fit.hhour.summaries[[i]][[w]]$residuals <- NULL
    fe.fit.hhour.summaries[[i]][[w]]$fe <- NULL
    rm(fe.fit)
    gc()
    # Post current status and completion time.
    # The elapsed time is measured explicitly in minutes: an automatic
    # difftime picks its own units (secs/mins/hours), which made the
    # "minutes" label in the message wrong for short or long runs.
    elap.time <- difftime(Sys.time(), start, units = 'mins')
    time.per.iter <- elap.time / iter
    time.left <- (maxiter - iter) * time.per.iter
    est.end.time <- Sys.time() + time.left
    message(iter %&% ' out of ' %&% maxiter %&% ' (' %&%
              as.character(round(iter / maxiter, 3) * 100) %&%
              '%) complete. Estimated time remaining: ' %&%
              round(time.left, 1) %&% ' minutes, (' %&% est.end.time %&% ')')
    iter <- iter + 1
  }
}
end <- Sys.time()
difftime(end, start)
# Check sample sizes: one column per half-hour, one row per day type.
# Each stored element is a list of two summaries, so N has to be read from the
# inner objects; reading it from the outer element always gave NULL.
sapply(fe.fit.hhour.summaries[1:48], function(x) sapply(x, function(s) s$N))

# Collect treatment effects by half hour
# Four 48 x 2 matrices (half-hour x day type): point estimate, lower and upper
# 95% bounds, and p-value of the any.trt:trt.period coefficient.
trt.effects <- array(NA, dim = c(48, 2),
                     dimnames = list(1:48, c('Weekdays', 'Weekend/Holidays')))
trt.effects.lwr = trt.effects.upr = trt.effects.pval <- trt.effects

for (w in 1:2) {
  for (t in 1:48) {
    # Coefficient table of the stored summary for half-hour t, day type w:
    # column 1 = estimate, column 2 = standard error, column 4 = p-value
    coefs.temp <- coef(fe.fit.hhour.summaries[[t]][[w]])
    row.num <- which(rownames(coefs.temp) == 'any.trt:trt.period')
    trt.effects[t, w]      <- coefs.temp[row.num, 1]
    trt.effects.lwr[t, w]  <- coefs.temp[row.num, 1] - 1.96 * coefs.temp[row.num, 2]
    trt.effects.upr[t, w]  <- coefs.temp[row.num, 1] + 1.96 * coefs.temp[row.num, 2]
    trt.effects.pval[t, w] <- coefs.temp[row.num, 4]
  }
}

## Plot half-hourly effects (Figure 2 and Figure A.6)
# One figure per day type: w = 1 weekdays, w = 2 weekends/holidays
for (w in 1:2) {
  if (w == 1) {
    fig.name <- graph_dir %&% '/Hourly-Treatment-Effects-Weekdays-' %&% today() %&% '.pdf'
  } else {
    fig.name <- graph_dir %&% '/Hourly-Treatment-Effects-Holidays-and-Weekends-' %&% today() %&% '.pdf'
  }

  if (write.out == TRUE) pdf(fig.name, family = 'serif', width = 10, height = 7)
  par(mai = c(0.9, 0.9, 0.05, 0.15), family = 'serif') # for pdf output
  hour.temp <- (1:48) / 2 - 0.25 # shift left 15 minutes (0.25) so the points land in the middle of the time window

  yrange <- c(-0.2, 0.2)
  ygrid <- round(seq(yrange[1], yrange[2], by = 0.05), 2)
  signif.level <- 0.05

  # Empty canvas with custom axes, grid and period boundaries
  plot(trt.effects[, w] ~ hour.temp, type = 'n', lwd = 1, pch = 1, lty = 1, xaxt = 'n', yaxt = 'n', col = 'black', xlim = c(0, 24.5),
       ylim = yrange, xlab = 'Hour of Day', ylab = expression(paste('Treatment Effect (Log Points: ', Delta, 'ln(', Y[i], '))')), cex.lab = 1.5)
  axis(1, at = seq(0, 24, by = 2), cex.axis = 1.2, font = 2); abline(v = seq(0, 24, by = 2), lty = 3, col = 'gray70'); abline(h = 0, lty = 1, col = 'gray20')
  axis(2, at = ygrid, cex.axis = 1.2, font = 2); abline(h = ygrid, lty = 3, col = 'gray70')
  grid(col = 'gray70', nx = NA, ny = NULL); abline(v = seq(0, 24, by = 2), lty = 3, col = 'gray70')
  abline(v = c(17, 19), col = 'red', lty = 2, lwd = 2); abline(v = c(8, 23), col = 'darkblue', lty = 2, lwd = 2)

  # 95% confidence intervals as capped error bars
  segments(x0 = hour.temp, y0 = trt.effects.lwr[, w], y1 = trt.effects.upr[, w], lty = 1, lwd = 1, col = 'black', cex = 1.4)
  arrows(x0 = hour.temp, y0 = trt.effects.lwr[, w], y1 = trt.effects.upr[, w], lty = 1, lwd = 1, col = 'black', angle = 90, code = 3, length = 0.05, cex = 1.4)

  # Filled green points mark estimates significant at signif.level; every
  # estimate also gets an open circle on top
  plot.values.temp <- trt.effects[, w]
  plot.values.temp[trt.effects.pval[, w] > signif.level] <- NA
  points(plot.values.temp ~ hour.temp, lwd = 1, pch = 19, col = 'green', lty = 1, cex = 1.4)
  points(trt.effects[, w] ~ hour.temp, type = 'p', lwd = 1, pch = 1, lty = 1, cex = 1.4)

  # Two overlaid legends: the first draws the line segments and text, the
  # second places the point symbols on top of those segments.
  # Legend text spelling corrected ("Insignificant", "Significant").
  legend(x = -0.8, y = yrange[2], legend = c('Point Estimate (Insignificant)', 'Point Estimate (Significant)'), lty = c(1, 1), pch = c(NA, NA), col = c('black', 'black'), cex = 1.2, bg = 'white', bty = 'n', text.font = 2)
  legend(x = -0.8 + 0.165, y = yrange[2] - 0.00, legend = c('', ''), lty = c(0, 0), pch = c(1, 19), col = c('black', 'green'), cex = 1.2, bg = 'white', bty = 'n')

  # Period labels along the bottom of the panel
  text.height <- yrange[1] + 0.02
  text.size <- 1.5
  text(x = 4, y = text.height, labels = 'Night', col = 'darkblue', cex = text.size, font = 2)
  text(x = 24.2, y = text.height, labels = 'Night', col = 'darkblue', cex = text.size, font = 2)
  text(x = 13, y = text.height, labels = 'Day', col = 'darkgreen', cex = text.size, font = 2)
  text(x = 21, y = text.height, labels = 'Day', col = 'darkgreen', cex = text.size, font = 2)
  text(x = 18.05, y = text.height, labels = 'Peak', col = 'red', cex = text.size, font = 2)

  dev.off()
}


## Aggregate and merge consumption and survey data ----
# Collapse the high-resolution panel to one mean kwh per household (id),
# treatment period, rate and period of day.
dt.reg = dt[, .(kwh=mean(kwh)), by=.(any.trt, trt.period, id, rate, period)]
# drop full high resolution panel to save memory. 
# When needed later, we can reread them from .Rdata files saved above
# i.e., for dt, load(file=output_dir%&%'/consumption_panel.Rdata')
# and for dt.nonweekday, load(output_dir%&%'/consumption_panel_nonweekday.Rdata')
rm(dt); rm(dt.nonweekday); gc() 

# Reshape wide over period of day: one row per (id, any.trt, trt.period) with
# columns kwh.<period> and rate.<period> (e.g. kwh.Peak, rate.Night)
dt.reg = dcast(dt.reg, id + any.trt + trt.period ~ period , value.var=c('kwh', 'rate'), sep='.')

# Merge tariff data in
dt.reg = merge(dt.reg, survey[, .(id, tariff)], by='id', all.y=FALSE)

# Convert into an N-row data table with change in kwh (baseline to treatment periods) for each household
# The second dcast appends '_base' / '_trt.period' to each value column
# (e.g. kwh.Peak_base, kwh.Peak_trt.period), which the code below relies on.
dt.reg = dt.reg[order(id, trt.period)]
dt.reg[trt.period==1, trt.period2 := 'trt.period']
dt.reg[trt.period==0, trt.period2 := 'base']
dt.reg = dcast(dt.reg, id + any.trt ~ trt.period2, 
               value.var=c('kwh.Peak', 'kwh.Night', 'kwh.Day', 'rate.Peak', 'rate.Night', 'rate.Day'))

# Percentage change in consumption, separately by period of day
# (treatment-period mean minus baseline mean, divided by baseline mean)
dt.reg[, d.kwh.Peak := (kwh.Peak_trt.period - kwh.Peak_base)/kwh.Peak_base]
dt.reg[, d.kwh.Night := (kwh.Night_trt.period - kwh.Night_base)/kwh.Night_base]
dt.reg[, d.kwh.Day := (kwh.Day_trt.period - kwh.Day_base)/kwh.Day_base]

# Merge survey data in
# (all survey columns except 'kwh' and 'peak.kwh'; only ids present in both tables are kept)
dt.reg = merge(dt.reg, survey[, -c('kwh','peak.kwh'), with=FALSE], by='id')

# Average treatment effect: -8.9%
summary(lm(d.kwh.Peak ~ any.trt, data=dt.reg))
dt.reg[, .N, by=any.trt]

# Turn subtreatment variables (tariff group, info group) into indicators
# Expand control observations so we can split on treatment types
# The treatment variables are tariff and stimulus
# dummy() names its columns with a 'dt.reg' prefix, which is rewritten to
# 'tariff.' / 'stimulus.' below. The '.E' column is dropped from each set
# (presumably the control category -- confirm against the assignment data).
dum.tariff = dummy(dt.reg[, tariff])
colnames(dum.tariff) = gsub('dt.reg','tariff.', colnames(dum.tariff))
dum.tariff = dum.tariff[,colnames(dum.tariff)!='tariff.E']
head(dum.tariff)

dum.stimulus = dummy(dt.reg[, stimulus])
colnames(dum.stimulus) = gsub('dt.reg','stimulus.', colnames(dum.stimulus))
# Give the numbered stimulus dummies descriptive names
colnames(dum.stimulus)[colnames(dum.stimulus)=='stimulus.1'] = 'stim.bm.bill'
colnames(dum.stimulus)[colnames(dum.stimulus)=='stimulus.2'] = 'stim.mon.bill'
colnames(dum.stimulus)[colnames(dum.stimulus)=='stimulus.3'] = 'stim.bm.bill.ihd'
colnames(dum.stimulus)[colnames(dum.stimulus)=='stimulus.4'] = 'stim.bm.bill.olr'
dum.stimulus = dum.stimulus[,colnames(dum.stimulus)!='stimulus.E']
head(dum.stimulus)

dt.reg = cbind(dt.reg, dum.tariff, dum.stimulus)

# Remove excess variables
# NOTE(review): 'tariff' appears twice in the first exclusion vector; the
# repeat looks redundant -- confirm it is harmless with the data.table version in use.
dt.reg = dt.reg[, -c('tar_stim','tariff','stimulus','d_tariff','d_stim','d_ihd', 'd_olr_group','d_monthly_bill',
                     'd_bimonthly_bill','d_group','d_treatment','d_grp_none', 'tariff'), with=FALSE]
dt.reg = dt.reg[, -c('kwh.Peak_trt.period','kwh.Night_trt.period','kwh.Day_trt.period'), with=FALSE]
dt.reg = dt.reg[, -c('rate.Peak_base','rate.Night_base','rate.Day_base'), with=FALSE]

# Designate cross validation groups.
# Do this manually so that duplicated observations (implemented below, as part of Athey-Imbens
# extension) are always included together. I.e., household 1's pseudo-observations
# are always included together in the same CV group. This avoids "peeking" at the 
# test set by including the same household in both the training and test sets.
# Do this before duplicating so that the cv.grp sticks around in the duplicated observations
# Folds are assigned separately within the control and treated rows: each
# group gets the labels 1..K recycled to its size and then shuffled.
K = 10 # folds
control_idx <- which(dt.reg[, any.trt] == 0)
treat_idx <- which(dt.reg[, any.trt] == 1)
xgroups <- rep(0, dt.reg[, .N])
set.seed(1) # fixed seed so the fold assignment is reproducible
xgroups[control_idx] <- sample(rep(1:K, length = length(control_idx)), length(control_idx), replace = F)
xgroups[treat_idx] <- sample(rep(1:K, length = length(treat_idx)), length(treat_idx), replace = F)
dt.reg[, cv.grps:=xgroups]

## 3) Balance Tests  ----
# Regress the treatment indicator on all complete ex ante covariates, with a
# linear probability model and a logit, and check how many covariates are
# individually significant, with and without a Bonferroni correction.
# The Bonferroni denominator is the number of covariates, i.e. the number of
# coefficient rows minus one for the intercept.

# Function for checking completeness
# (returns, for each column of DF, the number of non-missing values)
check.nas = function(DF) sapply(DF, function(x) sum(!is.na(x)))
# Use variables with complete data to ensure we're using the same sample everywhere.
balance.vars = names(which(check.nas(dt.reg)==dt.reg[,.N]))

# Drop ex post variables and outcome variables
drop.vars = grep('^stim',names(dt.reg), v=T)
drop.vars = c(drop.vars, balance.vars[(which(balance.vars %in% names(survey.post)))])
drop.vars = c(drop.vars, grep('tariff',names(dt.reg), v=T))
drop.vars = c(drop.vars, grep('_post',names(dt.reg), v=T))
drop.vars = c(drop.vars, grep('d_installed',names(dt.reg), v=T))
drop.vars = c(drop.vars, 'cv.grps', 'id','d.kwh.Peak', 'd.kwh.Night', 'd.kwh.Day')
drop.vars = c(drop.vars, 'rate.Peak_trt.period', 'rate.Night_trt.period', 'rate.Day_trt.period')

balance.vars = balance.vars[!(balance.vars %in% drop.vars)]

dt.balance = dt.reg[, balance.vars, with=FALSE]

## LPM of Balance
lpm.fit = lm(any.trt ~ . , data=dt.balance)
lpm.sum = summary(lpm.fit)
lpm.coefs = lpm.sum$coefficients
lpm.coefs = lpm.coefs[order(lpm.coefs[, 4]),] # sort by p-value (column 4)
lpm.coefs[grep('kwh', rownames(lpm.coefs)),]
which(lpm.coefs[,4]<0.05)
dim(lpm.coefs)

which(lpm.coefs[,4]<0.05/(nrow(lpm.coefs)-1)) # only the intercept is significant with Bonferroni correction
lpm.coefs[lpm.coefs[,4]<0.05/(nrow(lpm.coefs)-1),] # don't count intercept in multiple testing

## Logit of Balance
# Also run logit to check it gives similar results
logit.fit = glm(any.trt ~ . , data=dt.balance, family='binomial')
# logit.fit = glm(any.trt ~ d_has_kids + n_appliances + n_elec_shower , data=dt.logit, family='binomial')
logit.sum = summary(logit.fit)
logit.coefs = logit.sum$coefficients
# logit.coefs = logit.coefs[-1,] # drop intercept
# logit.coefs = logit.coefs[!is.na(logit.coefs[,1]),] # drop excluded dummies (appearing as NAs)
logit.coefs = logit.coefs[order(logit.coefs[, 4]),]

# Note the parentheses: nrow(x-1) is just nrow(x), which would count the
# intercept among the tests; (nrow(x)-1) excludes it as intended.
which(logit.coefs[,4]<0.05/(nrow(logit.coefs)-1)) # nothing is significant with Bonferroni correction
logit.coefs[logit.coefs[,4]<0.05/(nrow(logit.coefs)-1),] # don't count intercept in multiple testing

# Check they give the same significant coefficients
names(which(lpm.coefs[,4]<0.05)); names(which(logit.coefs[,4]<0.05))

range(predict(lpm.fit)) # ensure all between 0 and 1
quantile(predict(lpm.fit), probs=(1:99)/100) # ensure all between 0 and 1

# Choose which of the 100+ covariates to keep for the table
# NOTE(review): the pattern is a regular expression, so 'kwh.' matches 'kwh'
# plus any character and '65+' matches '6' followed by one or more '5's.
keep.vars = grep('Intercept|kwh.|n_adult|Retired|tv_large|65+|AB|Primary|internet|desktop|cook_type.Elec',
                 names(lpm.fit$coefficients), v=T) # keep kwh vars
keep.vars = c(keep.vars, names(which(lpm.coefs[,4]<0.05))) # ensure we keep significant ones
keep.vars = keep.vars[which(!is.na(lpm.fit$coefficients[keep.vars]))]
keep.vars = unique(keep.vars) # eliminate duplicates

# Inspect the coefficients selected for the table. This must come after
# keep.vars is built; referencing it earlier fails in a fresh session.
lpm.coefs[which(rownames(lpm.coefs) %in% keep.vars),]

# Output LPM (Table 2)
# Output Table 2
if (write.out==TRUE) stargazer(lpm.fit, out=output_dir%&%'/Tables/Table-Balance-LPM '%&%today()%&%'.tex', 
                               title='Linear Probability Model of Treatment on Covariates', align=TRUE, digits=2, 
                               dep.var.labels = 'Treated (Indicator)', omit.stat=c('ser'), no.space=TRUE, 
                               keep=keep.vars) 
lpm.coefs
lpm.coefs[grep('class', rownames(lpm.coefs)),]; grep('class', names(dt.balance), v=T)
dim(lpm.coefs)[1] - 1 # number of covariates (minus 1 for intercept)
dim(lpm.coefs)[1] - length(keep.vars) # number of covariates not shown

sum(lpm.coefs[, 'Pr(>|t|)']<0.05)/(nrow(lpm.coefs)-1) # 1.8% significant
 
which(lpm.coefs[,4]<0.05) # coefficients significant at 5% without any correction
which(lpm.coefs[,4]<0.05/(nrow(lpm.coefs)-1)) # nothing is significant with Bonferroni correction
lpm.coefs[lpm.coefs[,4]<0.05/(nrow(lpm.coefs)-1),] # don't count intercept in multiple testing

## Balance on Consumption, by subtreatment (Table 3)
# Attach each household's treatment assignment (dt_assign) to its baseline
# peak consumption, then regress baseline consumption on assignment.
dt.kwh = merge(dt.reg[,.(id, kwh.Peak_base)], dt_assign, by='id')

summary(lm(kwh.Peak_base ~ tar_stim!='EE', data=dt.kwh)) # All treatment 0.019*
# Separate regressions by tariff, by stimulus, and by tariff-stimulus cell;
# relevel() makes 'E' / 'EE' the reference category in each
bal.tar = summary(lm(kwh.Peak_base ~ relevel(factor(tariff),'E'), data=dt.kwh))
bal.stim = summary(lm(kwh.Peak_base ~ relevel(factor(stimulus),'E'), data=dt.kwh))
bal.tar_stim = summary(lm(kwh.Peak_base ~ relevel(factor(tar_stim),'EE'), data=dt.kwh))
bal.tar_stim

# Write the three coefficient tables to one dated workbook, one sheet each.
# The first call creates the file; the next two append sheets to it.
# NOTE(review): unlike the other outputs these writes are not guarded by write.out.
write.xlsx(bal.tar$coefficients, sheetName='tariff',
           file=output_dir%&%'/Tables/Balance-by-tariff_'%&%today()%&%'.xlsx')
write.xlsx(bal.stim$coefficients, append=TRUE, sheetName='stim',
           file=output_dir%&%'/Tables/Balance-by-tariff_'%&%today()%&%'.xlsx')
write.xlsx(bal.tar_stim$coefficients, append=TRUE, sheetName='tar_stim',
           file=output_dir%&%'/Tables/Balance-by-tariff_'%&%today()%&%'.xlsx')
summary(lm(kwh.Peak_base ~ tar_stim!='EE', data=dt.kwh))

# Construct propensity scores (used later, for robustness check)
# Fitted treatment probabilities from the balance logit, stored as a column of dt.reg
prop.score = predict(logit.fit, type = 'response')
summary(prop.score)
# Density of the scores centred on the sample share of treated households
plot(density(prop.score - mean(dt.reg[, any.trt]), adj=2)); abline(v=0, h=0)
dt.reg[, prop.score := prop.score]

#### 4) Causal Tree Analysis -----

## Implement data preprocessing to allow tree to split on subtreatment. ----
# Issue: causalTree will NEVER split on ihd or monthly bill because that will mean one leaf will have 0 treatment obs
# Resolution: We will duplicate the control group data M times, where M is the number of treatment
# groups. For each duplicate, set the m'th treatment dummy=1, and the rest=0 (for t=1,2...,M).
# We will then have N*M control observations, instead of the original N. The average of the N*M
# observations will be equal to the average of the original N. And when we split on a control group
# m, 1/M'th of the control observations will go into the "yes" category and (M-1)/M of them will
# go into the "no" category.

expand.controls = function(DT, trt, trt.vars) {
        # Replace each control observation with one pseudo-observation per
        # treatment dummy, so that a tree can split on the treatment dummies.
        #
        # Inputs:
        #   DT       - a data table with one row per observation
        #   trt      - treatment/control index vector (1 = treated, 0 = control),
        #              one entry per row of DT
        #   trt.vars - character vector of treatment dummy column names in DT
        # Output:
        #   A data table holding the treated rows unchanged, followed by
        #   T = length(trt.vars) copies of the control rows. In the i'th copy
        #   trt.vars[i] is set to 1; the other dummies keep the values they had
        #   in DT (expected to be 0 for controls). The original control rows
        #   themselves are not kept, so each control appears exactly T times.
        if (length(trt.vars) == 0) stop('trt.vars must name at least one treatment dummy')
        
        DT.trt = DT[trt==1]
        DT.ctrl = DT[trt==0]
        
        # One modified copy of the control data per treatment dummy.
        # seq_along() (rather than 1:length()) keeps the index valid for any length.
        ctrl.copies = lapply(seq_along(trt.vars), function(i) {
                x = copy(DT.ctrl) # copy so := below does not touch DT.ctrl
                x[, (trt.vars[i]) := 1]
                x
        })
        
        # Stack everything in one call instead of growing the table inside a loop;
        # row order is treated rows first, then the copies in trt.vars order.
        DT.out = rbindlist(c(list(DT.trt), ctrl.copies))
        return(DT.out)
}

# Construct Expanded Dataset with duplicated control observations
# (one pseudo-observation per stimulus dummy and per tariff dummy for every control household)
dt.reg.dupe = expand.controls(dt.reg, trt = dt.reg[, any.trt], trt.vars = c(colnames(dum.stimulus), colnames(dum.tariff)))
dt.reg.dupe

dt.reg.dupe = dt.reg.dupe[order(id)]

# Save the number of treatment groups
# (computed as the ratio of control rows after vs. before duplication)
M = dt.reg.dupe[any.trt==0, .N]/dt.reg[any.trt==0, .N]
# dup.weights = dt.reg.dupe[,ifelse(any.trt==0, 1/M, 1)] 
# note: using dup.weights causes 2 problems:
# 1) the weights are wrong after any split on treatment group. in particular, after a split on 
#    treatment, the effective "M" (number of treatment groups) is decremented by 1 in each child node.
#    so the weights need to vary by node, which is not possible here.
# 2) it changes the initial splits, which are on non-treatment dummies in each case.
#    with weights, it splits on baseline kwh before awareness. without, it reverses this order.
#    the order shouldn't matter.
# However, simulations suggest that the weighting doesn't matter in large samples 
# (either for false positives or false negatives). Because of the conceptual problem of using it,
# I don't use it. However, I did run sensitivities and the main result (regarding awareness & baseline consumption)
# is the same. The treatment groups don't split with the weights however.

# Check that the treatment/control group means are the same as before duplicating
# (each pair of lines prints the same columns for the expanded and the original data)
sapply(dt.reg.dupe[any.trt==0, .(id, d_has_kids, n_adults)], mean)
sapply(dt.reg[any.trt==0, .(id, d_has_kids, n_adults)], mean)
sapply(dt.reg.dupe[any.trt==1, .(id, d_has_kids, n_adults)], mean)
sapply(dt.reg[any.trt==1, .(id, d_has_kids, n_adults)], mean)
dt.reg.dupe[, table(any.trt)]; dt.reg[, table(any.trt)]

#### First, create some helper functions for handling the tree analysis ----

# Create function that will fix the proportions output by rpart.plot() to 
# reflect the number of treatment observations
replace.proportions = function(tree, data) {
        # Takes as input a tree object and the dataset that it was fit on. `data` must be a
        # data table with an any.trt column (1 = treated), which is used to pick the treated rows.
        # Returns the same tree object, but with the $frame$wt column changed to reflect the number
        # of treatment observations (not treatment+control). The same counts are also
        # stored in a new $frame$n_trt column.
        
        # Function to find all parents of a leaf.
        # Relies on the node numbering where the parent of node x is floor(x/2) and the
        # root is node 1; returns the path of node numbers from the root down to x.
        parent <- function(x) {
                if (x[1] != 1)
                        c(Recall(if (x %% 2 == 0L) x / 2 else (x - 1) / 2), x) else x
        }
        parent.trap <- function(x) {
                z = parent(x)
                return(z[-length(z)]) # drop the last one (the original input)
        }
        
        
        # Fix leaves to be number of treated obs in leaves
        # NOTE(review): the table names are used as row positions of tree$frame below, so this
        # assumes rpart.predict.leaves() returns frame row indices -- confirm against treeClust docs.
        library(treeClust)
        leaf.wt = table(rpart.predict.leaves(tree, newdata=data[any.trt==1])) # number of treat observations by terminal node
        
        tree$frame[as.numeric(names(leaf.wt)), 'wt'] = leaf.wt # overwrite
        tree$frame[as.numeric(names(leaf.wt)), 'n_trt'] = leaf.wt 
        # tree$frame[as.numeric(names(leaf.wt)), 'n_ctrl'] = table(tree$where[data[, any.trt==0]])/dupes.per.ctrl 
        
        # Fix parent nodes to be number of treated obs in nodes
        # Find leaves
        leaves = tree$frame[tree$frame$var=='<leaf>',]
        
        # Find nodes & leaves
        nodes = copy(tree$frame)
        
        # Find the leaf numbers
        # (row names of the frame are the node numbers)
        leaf.nos = as.numeric(rownames(leaves))
        leaf.nos = sort(leaf.nos)
        
        # Find the parents associated with each leaf
        # parent.nos is a list keyed by leaf number; each element holds that leaf's ancestors
        parent.nos = lapply(leaf.nos, parent.trap)
        names(parent.nos) = leaf.nos
        parent.nos
        
        # For each parent, add up the number of treatment observations contained in each leaf
        for (node in unique(unlist(parent.nos))) {
                # Collect the leaves that have this node among their ancestors
                node.leaves = c()
                for (lf in names(parent.nos)) {
                        if ((node %in% parent.nos[[lf]])==TRUE) {
                          node.leaves = c(node.leaves, lf)
                        }
                }
                
                # node.leaves holds leaf numbers as character strings, so the lookups
                # into `leaves` below are by row name
                nodes[rownames(nodes)==as.character(node), "wt"] = sum(leaves[node.leaves, "wt"])
                nodes[rownames(nodes)==as.character(node), "n_trt"] = sum(leaves[node.leaves, "n_trt"])
                # nodes[rownames(nodes)==as.character(node), "n_ctrl"] = sum(leaves[node.leaves, "n_ctrl"])
        }
        tree$frame = nodes
        
        return(tree)
}

# Helper function that plots a pruned version of a fitted tree.
#   tree    - fitted tree object with a $cptable (and $call$split.Rule for the default title)
#   nsplits - NULL to prune at the CP with the lowest cross-validated error; otherwise
#             prune at the cptable row whose number of splits is closest to nsplits
#   main    - plot title
#   data    - dataset the tree was fit on (passed to replace.proportions)
#   digits  - digits shown in the plot
plot.pruned.tree = function(tree, nsplits=NULL, main=tree$call$split.Rule, data, digits=2) {
        # Node weights are first rewritten to count treated observations only
        tree = replace.proportions(tree, data=data)
        cp.tab = tree$cptable
        
        # Pick the cptable row to prune at
        cp.row = if (is.null(nsplits)) {
                which.min(cp.tab[, 'xerror'])
        } else {
                which.min( abs(cp.tab[, 'nsplit'] - nsplits) )
        }
        
        rpart.plot(prune(tree, cp=cp.tab[, 1][cp.row]), main=main, digits=digits)
}

# Function to compute standard errors for each node of tree.
# Inputs:
#   tree           - fitted (possibly pruned) tree with $frame and $cptable
#   data           - data table with columns any.trt, id and the outcome named by yvar
#                    (plus prop.score when use.prop.score is TRUE)
#   use.prop.score - if TRUE, weight each node regression by 1/prop.score
#   yvar           - name of the outcome variable, as a character string
# Returns list(tree, res): the tree with SE, z and pval columns added to $frame, and
# res, a list keyed by node number holding the summary of each node's regression
# (NA for nodes where no regression was run).
tree.SEs = function(tree, data, use.prop.score=FALSE, yvar) {
        # Steps to compute standard errors for each node
        # 1) For each number of splits possible (0, 1, 2, ...), prune the tree. 
        # This lets us get the intermediate nodes as simply leaves of a smaller tree.
        # 2) For each pruned tree, find the observations falling in each leaf. Drop any 
        # duplicated control observations.
        # 3) For each leaf, run a regression to estimate the treatment effect and get SEs
        # 4) Return the regressions on each leaf as a list object
        
        # Set up frames to store results
        node.nums = rownames(tree$frame)
        cps = tree$cptable[, 'CP']
        
        # One slot per node of the full tree, initialised to NA
        res = list()
        for (i in 1:length(node.nums)) res[[i]] = NA
        names(res) = node.nums
        
        # Set up formula based on the given y variable name
        form = as.formula(paste0(yvar, '~ any.trt'))
        
        # For each complexity parameter (each corresponding to a different number of splits)
        # NOTE(review): a node only gets a regression if it is a leaf at one of the CP values
        # in the cptable; a node that is never a leaf at any listed CP keeps its NA.
        for (i in 1:length(cps)) {
                data.temp = copy(data)
                
                # Prune the tree according to each CP
                pruned.tree = prune(tree, cp=cps[i])
                
                # Find the leaf numbers
                leaves = pruned.tree$frame[pruned.tree$frame$var=='<leaf>',]
                leaf.nos = as.numeric(rownames(leaves))
                leaf.nos = sort(leaf.nos)
                
                # Find predicted treatment effects with this tree.
                tree.pred = predict(pruned.tree, newdata=data)
                
                # Find the leaf where each observation falls.
                # Observations are matched to nodes by joining the predicted value to the
                # node's yval; n records the original row order so it can be restored.
                # NOTE(review): this assumes yval is distinct across all rows of the pruned
                # frame; two nodes with the same yval would produce extra matches -- confirm.
                nodes = cbind(as.numeric(rownames(pruned.tree$frame)), pruned.tree$frame$yval)
                dt.where = merge(data.table(TE=tree.pred, n=1:length(tree.pred)), data.table(node=nodes[,1], TE=nodes[,2]), by='TE', all.x=TRUE)
                dt.where = dt.where[order(n)][, .(TE, node)]
                
                # Find the leaf location of each observation
                data.temp[, node:=dt.where[, node]]
                
                # For each leaf, run the regression only on the observations in that leaf
                for (l in 1:length(leaf.nos)) {
                        data.reg = data.temp[node==leaf.nos[l]][!duplicated(id)] # drop any duplicated id's
                        
                        # Set up weights (if necessary)
                        wt = rep(1, times=data.reg[, .N])
                        if(use.prop.score) wt = 1/data.reg[, prop.score]
                        
                        # Stored under the node number; a later CP overwrites an earlier result for the same node
                        res[[as.character(leaf.nos[l])]] = summary(lm(form, data=data.reg,
                                                                      weights=wt))
                }
        }
        
        # For each fitted regression, retrieve the standard error and place it in tree$frame
        # as a new column
        # (res is in the same order as the rows of tree$frame, so position i lines up)
        for (i in 1:length(res)) if(!is.na(res[[i]][1])) tree$frame[i, 'SE'] = res[[i]]$coefficients['any.trt','Std. Error']
        # z statistic and two-sided normal p-value from the node estimate (yval) and its SE
        tree$frame[,'z'] = tree$frame[,'yval']/tree$frame[,'SE']
        tree$frame[,'pval'] = round(2*pnorm(abs(tree$frame[,'z']), 0, 1, lower.tail=FALSE), 4)
        return(list(tree, res))
}

## Run causal tree algorithm ----
# Double check that CV groups are constant within id. Each sampled id should appear in
# exactly one cv.grps row of the table (controls with one count per pseudo-observation).
dt.reg.dupe[any.trt==0][id %in% sample(unique(id),10), table(cv.grps, id)]
dt.reg.dupe[any.trt==1][id %in% sample(unique(id),10), table(cv.grps, id)]

# Keep the fold labels and propensity scores as vectors aligned with the rows of dt.reg.dupe
cv.grps = dt.reg.dupe[, cv.grps]
prop.score.dupe = dt.reg.dupe[, prop.score]

# Remove irrelevant variables
dt.tree = dt.reg.dupe[, -c('id', 'any.trt', 'cv.grps', 'prop.score',
                           'rate.Peak_trt.period', 'rate.Day_trt.period', 
                           'rate.Night_trt.period'), with=FALSE]
names(dt.tree)
trt = dt.reg.dupe[, any.trt]

# Peak period tree
# The formula removes the other periods' outcomes and baselines from the candidate
# splitting variables; xval=cv.grps supplies the household-level folds built earlier,
# and minsize is 2% of the number of households in dt.reg.
c.tree.peak = causalTree(d.kwh.Peak ~ . - d.kwh.Day - kwh.Day_base - d.kwh.Night - kwh.Night_base, # - kwh.Peak_base,  
                  data=dt.tree, split.Rule='TOT', cv.option='TOT',
                  treatment = trt, split.Honest = F, cv.Honest = F, split.Bucket=F, 
                  xval=cv.grps, minsize=0.02*dt.reg[, .N])
# Check pruned tree results
# (plot.pruned.tree applies replace.proportions itself, so it is applied twice here)
plot.pruned.tree(replace.proportions(c.tree.peak, data=dt.reg.dupe), data=dt.reg.dupe)

# Repeat for nighttime
c.tree.night = causalTree(d.kwh.Night ~ . - d.kwh.Day - kwh.Day_base - d.kwh.Peak - kwh.Peak_base, 
                         data=dt.tree, split.Rule='TOT', cv.option='TOT',
                         treatment = trt, split.Honest = F, cv.Honest = F, split.Bucket=F, 
                         xval=cv.grps, minsize=0.02*dt.reg[, .N]) 
# Repeat for daytime
c.tree.day = causalTree(d.kwh.Day ~ . - d.kwh.Peak - kwh.Peak_base - d.kwh.Night - kwh.Night_base, 
                          data=dt.tree, split.Rule='TOT', cv.option='TOT',
                          treatment = trt, split.Honest = F, cv.Honest = F, split.Bucket=F, 
                          xval=cv.grps, minsize=0.02*dt.reg[, .N])

# Propensity score peak tree (sensitivity)
# Same specification as the peak tree, with observation weights 1/prop.score
c.tree.peak.ps = causalTree(d.kwh.Peak ~ . - d.kwh.Day - kwh.Day_base - d.kwh.Night - kwh.Night_base, 
                            data=dt.tree, split.Rule='TOT', cv.option='TOT',
                            treatment = trt, split.Honest = F, cv.Honest = F, split.Bucket=F, 
                            xval=cv.grps, minsize=0.02*dt.reg[, .N], weights=1/prop.score.dupe)

# When plotting, adjust the sample sizes to reflect the original number of
# households, not including the pseudo-observations introduced in the 
# expand.controls() function
c.tree.peak = replace.proportions(c.tree.peak, data=dt.reg.dupe)
c.tree.peak.ps = replace.proportions(c.tree.peak.ps, data=dt.reg.dupe)
c.tree.night = replace.proportions(c.tree.night, data=dt.reg.dupe)
c.tree.day = replace.proportions(c.tree.day, data=dt.reg.dupe)

## Compare trees before optimal pruning
# (nsplits=6 shows each tree at the cptable row closest to 6 splits)

# Compare regular peak tree to peak propensity tree
par(mfrow=c(1,2), mai=c(1,1,0.5,0.5))
plot.pruned.tree(c.tree.peak, data=dt.reg.dupe, nsplits=6, digits=3, main='Peak Hours')
plot.pruned.tree(c.tree.peak.ps, data=dt.reg.dupe, nsplits=6, digits=3, main='Peak Hours')

# Compare peak, night, and daytime trees
par(mfrow=c(1,3), mai=c(1,1,0.5,0.5))
plot.pruned.tree(c.tree.peak, data=dt.reg.dupe, nsplits=6, digits=3, main='Peak Hours')
plot.pruned.tree(c.tree.night, data=dt.reg.dupe, nsplits=6, digits=3, main='Night Hours')
plot.pruned.tree(c.tree.day, data=dt.reg.dupe, nsplits=6, digits=3, main='Day Hours')

## And with optimal pruning
# (nsplits=NULL prunes at the CP with the lowest cross-validated error)
# Regular Peak vs peak propensity
par(mfrow=c(1,2), mai=c(1,1,0.5,0.5))
plot.pruned.tree(c.tree.peak, data=dt.reg.dupe, nsplits=NULL, digits=3, main='Peak Hours')
plot.pruned.tree(c.tree.peak.ps, data=dt.reg.dupe, nsplits=NULL, digits=3, main='Peak Hours')

# Peak, night, and daytime trees
par(mfrow=c(1,3), mai=c(1,1,0.5,0.5))
plot.pruned.tree(c.tree.peak, data=dt.reg.dupe, nsplits=NULL, digits=3, main='Peak Hours')
plot.pruned.tree(c.tree.night, data=dt.reg.dupe, nsplits=NULL, digits=3, main='Night Hours')
plot.pruned.tree(c.tree.day, data=dt.reg.dupe, nsplits=NULL, digits=3, main='Day Hours')

# Check complexity parameter graphs
# (the red line marks the minimum cross-validated error of each tree)
plotcp(c.tree.peak); abline(h=min(c.tree.peak$cptable[, 'xerror']), col='red')
plotcp(c.tree.night); abline(h=min(c.tree.night$cptable[, 'xerror']), col='red')
plotcp(c.tree.day); abline(h=min(c.tree.day$cptable[, 'xerror']), col='red')

# Note: I have used TOT because the other methods have terrible out-of-sample fit
# (flat rel. error/CP plots). Also, simulations show that they fail to detect 
# true splits on covariates correlated with treatment.

## Next: prune trees, then calculate standard errors in each node through
# a separate regression. The function above tree.SEs does this.
# Simple function to optimally prune a tree: prunes at the complexity parameter
# of the cptable row with the lowest cross-validated error ('xerror') and
# returns the pruned tree.
opt.prune = function(tree) { 
        cp.tab = tree$cptable
        best.row = which.min(cp.tab[, 'xerror'])
        return(prune(tree, cp=cp.tab[, 1][best.row]))
}

# Replace each tree with its optimally pruned version
c.tree.peak = opt.prune(c.tree.peak)
c.tree.peak.ps = opt.prune(c.tree.peak.ps)
c.tree.night = opt.prune(c.tree.night)
c.tree.day = opt.prune(c.tree.day)

# Add in standard errors for each node
# NOTE(review): the propensity-score tree is passed without use.prop.score=TRUE, so its
# node regressions are unweighted (the function's default) -- confirm this is intended.
tree.regs.peak = tree.SEs(tree=c.tree.peak, data=dt.reg.dupe, yvar='d.kwh.Peak')
tree.regs.peak.ps = tree.SEs(tree=c.tree.peak.ps, data=dt.reg.dupe, yvar='d.kwh.Peak')
tree.regs.night = tree.SEs(tree=c.tree.night, data=dt.reg.dupe, yvar='d.kwh.Night')
tree.regs.day = tree.SEs(tree=c.tree.day, data=dt.reg.dupe, yvar='d.kwh.Day')

# tree.SEs returns list(tree, regressions); copy the frame with the SE, z and pval columns back
c.tree.peak$frame = tree.regs.peak[[1]]$frame
c.tree.peak.ps$frame = tree.regs.peak.ps[[1]]$frame
c.tree.night$frame = tree.regs.night[[1]]$frame
c.tree.day$frame = tree.regs.day[[1]]$frame

# One node was missed by tree.SEs for some reason. Fill it in by hand.
# It was node 7, which is aware & base consumption >= 0.12 & NOT IHD & NOT Monthly bill
# NOTE(review): possibly because tree.SEs only fits nodes that are leaves at one of the
# CP values in the cptable -- confirm. The threshold 0.1198157 is hard-coded from the
# fitted tree and must be updated if the tree changes.
c.tree.peak$frame
dt.temp = dt.reg[d_aware_of_tariff_change==1 &
                   kwh.Peak_base>=0.1198157 & 
                   stim.bm.bill.ihd==0 &
                   stim.mon.bill==0]
node.7.reg = summary(lm(d.kwh.Peak ~ any.trt,  data=dt.temp))
node.7.reg
# The row to fill is located as the frame row whose splitting variable is stim.bm.bill.olr
c.tree.peak$frame[which(as.character(c.tree.peak$frame$var)=='stim.bm.bill.olr'),c('SE','z','pval')] <-  
  node.7.reg$coefficients['any.trt',c('Std. Error','t value','Pr(>|t|)')]

## Plot nice tree graph (Figure 5) ----
# c.tree.peak was already pruned above; pruning again at the same criterion
p.tree.peak = opt.prune(c.tree.peak)
# customized plotting instruction
# https://cran.r-project.org/web/packages/rpart.plot/rpart.plot.pdf
# extra=6/7/8 shows share within node. 9/10 show overall shares. add 100 to include both within-node shares and overall shares

# Work on a copy so the edits made for plotting do not alter p.tree.peak
tree = copy(p.tree.peak)

node.fun1 <- function(x, labs, digits, varlen)
{
  # Node label callback for prp(): returns one three-line label per row of
  # x$frame, showing the treatment effect and its standard error (both as
  # percentages rounded to one decimal) and the number of treated observations.
  # labs, digits and varlen are part of the callback signature and are not used.
  node.info = x$frame
  te.line = paste0("TE: ", as.character(round(node.info$yval,3)*100), "%")
  se.line = paste0("SE: (", as.character(round(node.info$SE,3)*100), "%)")
  n.line = paste0("n: ", node.info$n_trt)
  paste(te.line, se.line, n.line, sep='\n')
}

split.labs <- function(x, labs, digits, varlen, faclen) {
  # Diagnostic split-label callback: prefixes every label with its position
  # (e.g. "1 <label>") so the labels can be identified before writing the
  # customized split.labs defined further below. Only labs is used.
  # seq_along() keeps this correct for an empty labs vector, where
  # 1:length(labs) would have produced a spurious "1 NA" entry.
  labs[] = paste(seq_along(labs), labs)
  # labs[1] = 'Aware of Tariff == 1'
  labs
}

# This returns the default labels. Use this info to build a split.labs function to
# adjust them
# (calls an unexported rpart.plot function via ':::'; may break with other package versions)
default.labs = rpart.plot:::internal.split.labs(tree, type=2, clip.left.labs = FALSE, clip.right.labs = FALSE, 
                                                xflip=FALSE, digits=2, varlen= 25, faclen=3, facsep=",", 
                                                eq=" = ", lt = " < ", ge = " >= ", split.prefix = "",
                                                right.split.prefix = NULL, split.suffix = "", right.split.suffix = NULL)
default.labs

# Share of households below each baseline-consumption threshold (used for the
# percentile text in the split labels)
dt.reg[, mean(kwh.Peak_base<0.25)]
dt.reg[, mean(kwh.Peak_base<0.12)]
# Swap the contents of frame rows 12 and 13 so that both sides of the plotted tree
# read as "consumption >= threshold". The row positions are hard-coded for the
# current fitted tree; row names are not swapped by these assignments.
a = tree$frame[12,]
b = tree$frame[13,]
tree$frame[12,] = b
tree$frame[13,] = a

split.labs <- function(x, labs, digits, varlen, faclen) {
  # Split-label callback for prp(): swaps the default split labels for
  # readable text. Each name below is a regular expression matched against the
  # default labels; every matching label is replaced by the paired text.
  # Patterns are applied in the order listed. Only labs is used.
  relabel = c(
    'tariff_change >= 0.5' = 'Aware of Tariff Change?',
    # = 0.1198 kwh per half hour * 2 half hours per hour
    'kwh.Peak_base >= 0.12' = 'Baseline Average Peak \n Consumption >= 0.12 kWh \n (5th pctile)',
    'stim.bm.bill.ihd >= 0.5' = 'Info. Treatment: \n In-Home \n Display?',
    'stim.mon.bill >= 0.5' = 'Info. Treatment: \n Monthly \n Bill?',
    'stim.bm.bill.olr >= 0.5' = 'Overall Load Reduction Incentive?',
    # = 0.253 kwh per half hour * 2 half hours per hour
    'kwh.Peak_base < 0.25' = 'Baseline Average Peak \n Consumption >= 0.25 kWh \n (24th pctile)'
  )
  # Note that the last one swaps the sign. This is because we swap the values of the nodes below (a,b)
  # in order for both sides of the tree to have consistent splitting (both kwh >=)
  for (pattern in names(relabel)) {
    labs[grep(pattern, labs)] = relabel[[pattern]]
  }
  labs
}

# Standard wrapper
# Calls prp() with this script's preferred plotting defaults so every tree figure
# is styled the same way. All arguments are forwarded to prp() unchanged; any
# default can be overridden by the caller, and further prp() arguments can be
# passed through `...`. By default the node and split labels come from
# node.fun1 and split.labs as defined at the time of the call.
prp.defaults <- function(x, type=2, extra=0, under=TRUE, branch.lty=3, box.palette="GnYlRd", 
                         node.fun=node.fun1, yesno=1, ni=TRUE, fallen.leaves=TRUE, 
                         split.fun=split.labs, varlen=0, tweak=0.95, trace=TRUE, round=1, 
                         shadow.col='darkgray', font=2, ...) {
  prp(x=x, type=type, extra=extra, under=under, branch.lty=branch.lty,
      box.palette=box.palette, node.fun=node.fun, yesno=yesno, ni=ni,
      fallen.leaves=fallen.leaves, split.fun=split.fun, varlen=varlen, 
      tweak=tweak, trace=trace, round=round, shadow.col=shadow.col, font=font, ...)
}

# Output Peak Tree graph (Figure 5)
# With write.out TRUE the tree is written to a dated pdf in graph_dir;
# otherwise it is drawn on the current graphics device.
if (write.out==TRUE) pdf(file=graph_dir%&%'/Peak-Tree-'%&%today()%&%'.pdf', family='serif', width=10, height=7)
par(mfrow=c(1,1), family='serif', mai=c(0.85,0.85,0.4,0.15)) # for pdf output
prp.defaults(tree, cex=1.15, split.cex=1.15, tweak=0.95)
# Only close the device when a pdf was opened above; otherwise dev.off() would
# close the on-screen plot immediately after drawing it.
if (write.out==TRUE) dev.off()

#### 5) Theory-Driven Heterogeneity Estimation -----
# This is for Table 4
## Show theory-driven approach.
# Interact treatment effect with:
# baseline consumption (continuous and median cutoff)
# education
# income
# electric water heat
# immersion heater
# washing machine/tumble dryer

# Create a new copy of the dataframe for the theory-driven approach
dt.theory = copy(dt.reg)

# Convert income to factor
dt.theory[, table(f_hhincome, exclude=NULL)]
dt.theory[, f_hhincome:=as.character(f_hhincome)]
dt.theory[is.na(f_hhincome), f_hhincome:='Refused']
hhinc.key = c('1'='<�15k','2'='�15k - �30k','3'='�30k - 50k','4'='�50k - �75k', '5'='>�75k','Refused'='Refused')
dt.theory[, fs_hhincome:=hhinc.key[f_hhincome]]
dt.theory[, table(fs_hhincome)]

# Theory-driven regressions (Table 4): the change in peak consumption is
# regressed on the treatment indicator interacted with progressively richer
# sets of household characteristics.
lm.theory = list()

# (1) Treatment interacted with standardised baseline peak consumption only
lm.theory[[1]] = lm(d.kwh.Peak ~ any.trt*scale(kwh.Peak_base), data=dt.theory)
summary(lm.theory[[1]])

# (2) Add education and household income.
# The two 'Secondary' education dummies are summed into a single indicator,
# and 'Refused' is the reference income category.
lm.theory[[2]] = lm(d.kwh.Peak ~ 
                      any.trt*(scale(kwh.Peak_base)
                               + fs_education.None + fs_education.Third
                               + fs_education.Primary + I(`fs_education.Secondary to Cert` + `fs_education.Secondary without Cert`)
                               + relevel(as.factor(fs_hhincome),'Refused')), 
                    data=dt.theory)
summary(lm.theory[[2]])

# (3) Add appliance and water-heating indicators.
# Income enters through the labelled factor fs_hhincome, exactly as in model
# (2), so the income terms of both models carry the same coefficient names and
# share rows in the regression table. fs_hhincome is a relabelling of
# f_hhincome, so the grouping is unchanged provided every income code has a
# label in hhinc.key (see the tabulation of fs_hhincome).
lm.theory[[3]] = lm(d.kwh.Peak ~ 
                      any.trt*(scale(kwh.Peak_base)
                               + fs_education.None + fs_education.Third
                               + fs_education.Primary + I(`fs_education.Secondary to Cert` + `fs_education.Secondary without Cert`)
                               + relevel(as.factor(fs_hhincome),'Refused')
                               + d_tumble_dryer
                               + d_immersion + d_dishwasher
                               + I(d_waterheater_elec_instant + d_waterheater_elec_immersion >0)
                               + d_waterheater_timer), 
                    data=dt.theory)
summary(lm.theory[[3]])

# Print all three summaries together (lapply keeps each summary intact)
lapply(lm.theory, summary)
# Table 4: write the theory-driven regressions to a dated LaTeX file.
# Nothing is written unless write.out is TRUE.
if (write.out == TRUE) {
  stargazer(lm.theory,
            out = paste0(output_dir, '/Tables/Table-Theory-Driven-Regs ', today(), '.tex'),
            title = 'Theory-Driven Heterogeneity Estimation',
            align = TRUE,
            digits = 2,
            dep.var.labels = '\\Delta Peak Period Electricity Consumption',
            omit.stat = c('ser'),
            no.space = TRUE) # optional extra argument: report='vcs'
}

#### 6) Theory-Driven and LASSO-Driven Prediction of Awareness -----
## Run standard regression to predict awareness to complement LASSO approach 
# Make copy dataframe of just the treatment group
dt.aware = copy(dt.theory[any.trt==1])
# Drop endogenous, duplicative, or post-treatment data 
dt.aware = dt.aware[, -c('d.kwh.Day', 'd.kwh.Night', 'd.kwh.Peak',
                         'tariff.A', #'tariff.B', 'tariff.C', 'tariff.D', 
                         'stim.bm.bill', #'stim.mon.bill', 'stim.bm.bill.ihd', 'stim.bm.bill.olr', 
                         'cv.grps', 'prop.score',
                         'rate.Peak_trt.period', 'rate.Day_trt.period', 'rate.Night_trt.period'), with=FALSE]

# Drop post-trial survey questions, since they are often duplicative of awareness,
# or don't reflect data available to policymakers.
# A logical mask is used instead of x[-which(cond)]: the latter returns an
# empty vector (rather than all of x) whenever no name matches.
ex.ante.vars = names(dt.aware)[!(names(dt.aware) %in% names(survey.post))]
# Awareness is the dependent variable, so it is added back explicitly
ex.ante.vars = c(ex.ante.vars, 'd_aware_of_tariff_change')
# drop ex post variables, except awareness (dep. variable);
# unique() avoids selecting a column twice if it is already in ex.ante.vars
dt.aware = dt.aware[, unique(c('id',ex.ante.vars)), with=FALSE]
# keep only complete variables
# (check.nas is defined elsewhere in this file; presumably it returns a
# per-column count that equals the row count for fully observed columns - confirm)
dt.aware = dt.aware[, names(which(check.nas(dt.aware)==dt.aware[, .N])), with=FALSE]

# Hand-specified linear probability models of awareness, from sparse to rich.
# All three are fitted on dt.aware and collected in one list.
lms.aware <- list(
  # (1) Baseline peak consumption plus sub-treatment indicators
  lm(d_aware_of_tariff_change ~ scale(kwh.Peak_base) +
       (stim.mon.bill + stim.bm.bill.olr + stim.bm.bill.ihd +
          tariff.B + tariff.C + tariff.D),
     data = dt.aware),

  # (2) Baseline peak consumption interacted with the sub-treatment indicators
  lm(d_aware_of_tariff_change ~ scale(kwh.Peak_base) *
       (stim.mon.bill + stim.bm.bill.olr + stim.bm.bill.ihd +
          tariff.B + tariff.C + tariff.D),
     data = dt.aware),

  # (3) Model (2) plus "reasonable" household covariates: number of residents,
  #     gender, age band, social class and education
  lm(d_aware_of_tariff_change ~ scale(kwh.Peak_base) *
       (stim.mon.bill + stim.bm.bill.olr + stim.bm.bill.ihd +
          tariff.B + tariff.C + tariff.D) +
       scale(n_residents) + fs_gender.Female +
       `fs_age.18-25` + `fs_age.26-35` + `fs_age.36-45` + `fs_age.56-65` + `fs_age.65+` + `fs_age.NA` +
       fs_social_class.AB + fs_social_class.C1 + fs_social_class.C2 + fs_social_class.DE + fs_social_class.Farmer +
       fs_education.None +
       fs_education.Primary +
       I(`fs_education.Secondary to Cert` +
           `fs_education.Secondary without Cert`) +
       fs_education.Third,
     data = dt.aware)
)
# Show the regression output of each model
lapply(lms.aware, summary)

# Get predicted values
# Fitted awareness probabilities, in percent, from each hand-specified LPM,
# overlaid in one histogram (one border colour per model).
lpm.pred.chosen = list()
# mycols = c(rgb(1,0,0,0.25),  rgb(0,1,0,0.25), rgb(0,0,1,0.5)) # translucent alternative
mycols = c('red','green','blue')
for (i in seq_along(lms.aware)) {
  # Cap fitted values at 1 (an LPM can predict above 100%), then convert to percent
  lpm.pred.chosen[[i]] = pmin(predict(lms.aware[[i]]), 1)*100
  # The first model opens the plot; later models are drawn on top of it
  hist(lpm.pred.chosen[[i]], xlim=c(0,100), add=(i>1), border=mycols[i])
}

# In the simplest model (baseline consumption plus sub-treatment indicators),
# the lowest predicted probability of awareness is 79%.
range(lpm.pred.chosen[[1]])
# In the richest of these three models, the lowest predicted probability of 
# awareness is 61%.
range(lpm.pred.chosen[[3]])

### Run Lassos to select model
# Lasso the Logit/LPM
# Scale to mean 0, std dev 1, for interpretation's sake.
# scale() returns a one-column matrix carrying scaled:center/scaled:scale
# attributes; as.numeric() strips these so a plain numeric column is stored
# and the interaction columns built from it are plain vectors as well.
dt.aware[, kwh.Peak_base := as.numeric(scale(kwh.Peak_base))]

# Lasso algorithm requires dummies/interactions be implemented as indicator variables
# Create interactions for tariffs
dt.aware[, kwh.Peak_base_tarB := kwh.Peak_base*tariff.B]
dt.aware[, kwh.Peak_base_tarC := kwh.Peak_base*tariff.C]
dt.aware[, kwh.Peak_base_tarD := kwh.Peak_base*tariff.D]

# Create interactions for stims
dt.aware[, kwh.Peak_base_stim2 := kwh.Peak_base*stim.mon.bill]
dt.aware[, kwh.Peak_base_stim3 := kwh.Peak_base*stim.bm.bill.olr]
dt.aware[, kwh.Peak_base_stim4 := kwh.Peak_base*stim.bm.bill.ihd]

# Expand labelled income into dummy columns (add.dummy is defined elsewhere in
# this file), then remove the raw income code column
dt.aware = add.dummy(dt.aware , 'fs_hhincome')
dt.aware[, f_hhincome:=NULL]

# Drop id and any.trt so they're not used as independent variables
dt.aware = dt.aware[, -c('id', 'any.trt'), with=FALSE]

# LPM of Awareness on every remaining column of dt.aware
aware.lpm = lm(d_aware_of_tariff_change ~ . , data=dt.aware)

aware.sum = summary(aware.lpm)

# Coefficient table; column 4 holds the p-values
aware.coefs = aware.sum$coefficients
aware.coefs[order(aware.coefs[,4], decreasing=TRUE),]
# Share of coefficients significant at the 5% level
mean(aware.coefs[,4]<0.05)

# Due to space limitations, only a subset of the ~150 included variables in column (1) are shown.
sum(!is.na(aware.coefs[,4]))
sum(aware.coefs[,4]<0.05)
dim(aware.coefs)
aware.coefs[aware.coefs[,4]<0.05,]

# Bonferroni correction
# don't count intercept in multiple testing: the threshold divides by the
# number of slope coefficients only
slope.coefs = aware.coefs[-1, , drop=FALSE]
bonf.cutoff = 0.05/nrow(slope.coefs)
which(aware.coefs[,4]<bonf.cutoff) # only the intercept is significant with Bonferroni correction
# Filter the slope rows with a mask built from those same rows, so each
# p-value is matched to its own coefficient (a mask computed without the
# intercept row must not be applied to the table that still contains it)
slope.coefs[slope.coefs[,4]<bonf.cutoff, , drop=FALSE]

# Lasso of LPM
library(glmnet)
# Drop dependent variable and collinear dummies (otherwise the matrix can't invert)
drop.vars = c('d_aware_of_tariff_change','fs_gender_post.Female','fs_education.Third')
X.class = as.matrix(dt.aware[, -drop.vars, with=FALSE])
dim(X.class)
# Full-data LASSO path (alpha=1) with awareness as a numeric outcome
aware.lasso = glmnet(x=X.class, y=dt.aware[,d_aware_of_tariff_change],
                         alpha=1) #, family='binomial')
plot(aware.lasso, xvar='dev')

# 10-fold cross-validation to choose the penalty; seeded for reproducible folds
set.seed(1)
aware.lasso.cv = cv.glmnet(x=X.class, y=dt.aware[,d_aware_of_tariff_change],
                               alpha=1, nfolds=10) #, family='binomial')

plot(aware.lasso.cv)
lam.min = aware.lasso.cv$lambda.min
lam.1se = aware.lasso.cv$lambda.1se

# Locate each chosen lambda on the full-data path by nearest value: an exact
# `==` on floating point numbers selects nothing if the two lambda sequences
# differ in the last bit.
# Keep every selected variable (non-zero coefficient) whatever its sign; a
# `> 0` filter would silently discard selected predictors with negative
# coefficients.
# NOTE(review): figures recorded in comments elsewhere in this file (e.g.
# covariate counts) may need refreshing after re-running this selection.
lam.min.idx = which.min(abs(aware.lasso$lambda - lam.min))
lasso.beta.min = aware.lasso$beta[,lam.min.idx]
lasso.beta.min = lasso.beta.min[lasso.beta.min!=0]
lasso.beta.min

lam.1se.idx = which.min(abs(aware.lasso$lambda - lam.1se))
lasso.beta.1se = aware.lasso$beta[,lam.1se.idx]
lasso.beta.1se = lasso.beta.1se[lasso.beta.1se!=0]
lasso.beta.1se

# Run LASSO post-selection regressions
# Builds the awareness formula from a vector of selected variable names.
# Every name is surrounded in quotes (`like this`) because some contain spaces
# or special characters. An empty selection gives an intercept-only model
# instead of reusing whatever formula was built previously.
lasso.formula = function(vars) {
  rhs = if (length(vars)>0) paste('`'%&%vars%&%'`', collapse=' + ') else '1'
  # env=parent.frame() keeps the formula's environment that of the caller
  as.formula('d_aware_of_tariff_change ~ ' %&% rhs, env=parent.frame())
}

# Variables selected at lambda.min
aware.vars = names(lasso.beta.min)
aware.form = lasso.formula(aware.vars)
lpm.aware.min = lm(aware.form, data=dt.aware)

# Variables selected at lambda.1se
aware.vars = names(lasso.beta.1se)
aware.form = lasso.formula(aware.vars)
lpm.aware.1se = lm(aware.form, data=dt.aware)

# Compare post-selection OLS estimates with the penalised LASSO coefficients
summary(lpm.aware.min)
lasso.beta.min
summary(lpm.aware.1se)
lasso.beta.1se

# Inspect the fitted awareness probabilities of each model.
# Fitted values are capped at 1 and expressed in percent.

# Full LPM on all covariates
lpm.pred <- pmin(predict(aware.lpm), 1)*100
mean(lpm.pred <= 50) # share of households with a fitted probability of at most 50%
plot(density(lpm.pred))
range(lpm.pred)
hist(lpm.pred, xlim = 100*c(0, 1.2))

# Post-LASSO model at lambda.min
lpm.pred.min <- pmin(predict(lpm.aware.min), 1)*100
range(lpm.pred.min)
hist(lpm.pred.min, xlim = 100*c(0, 1.2))

# Post-LASSO model at lambda.1se
lpm.pred.1se <- pmin(predict(lpm.aware.1se), 1)*100
hist(lpm.pred.1se, xlim = 100*c(0, 1))
range(lpm.pred.1se)

# Output Awareness Regression Tables (Table 5)
# Columns: the three hand-specified LPMs followed by the two post-LASSO LPMs.
# Nothing is written unless write.out is TRUE.
# The table title previously misspelled "Probability".
if (write.out==TRUE) stargazer(lms.aware, lpm.aware.min, lpm.aware.1se, out=output_dir%&%'/Tables/Table-Aware-LPMs '%&%today()%&%'.tex', 
                               title='Linear Probability Model of Awareness on Covariates', align=TRUE, digits=2, 
                               dep.var.labels = 'Aware of Tariff Change (Indicator)', omit.stat=c('ser'), no.space=TRUE)

# Number of covariates in each awareness model: estimated (non-NA)
# coefficients, less one for the intercept.
# The trailing comments are the counts noted by the author.
# sum(!is.na(coef(aware.lpm))) - 1 # 147
sum(!is.na(coef(lms.aware[[1]]))) - 1 # 7
sum(!is.na(coef(lms.aware[[2]]))) - 1 # 13
sum(!is.na(coef(lms.aware[[3]]))) - 1 # 30
sum(!is.na(coef(lpm.aware.min))) - 1 # 16
sum(!is.na(coef(lpm.aware.1se))) - 1 # 1

# Save the whole workspace to the output directory
save.image(file = paste0(output_dir, 'working_electricity_data.Rdata'))

